def train_discriminator(nsteps):
    """Train the discriminator (``adversary``) for ``nsteps`` batches; return its mean loss.

    Each step works as follows:

    * Draw ``args.batch_size`` real targets from ``Y_train``.
    * Draw an independent random batch of ``(O_train, T_train, Y_train)``
      inputs and run ``generator`` on it to obtain the same number of fake
      targets.
    * Stack the two sets (real first, then fake) and fit ``adversary`` on
      the stack against the label array ``d_disc``.

    This function reads these module-level names defined elsewhere in the
    file: ``np``, ``args``, ``O_train``, ``T_train``, ``Y_train``,
    ``generator``, ``adversary`` and ``d_disc``.

    Args:
        nsteps: Number of training steps (batches) to run; must be >= 1.

    Returns:
        The arithmetic mean of the per-step losses reported by ``adversary.fit``.

    Raises:
        ValueError: If ``nsteps`` is less than 1.
    """
    if nsteps < 1:
        raise ValueError(f"nsteps must be >= 1, got {nsteps}")

    n_samples = O_train.shape[0]
    total_loss = 0.0
    # Run exactly nsteps iterations so the mean below divides by the true step count.
    for _ in range(nsteps):
        # Real samples: ground-truth targets.
        real_idx = np.random.randint(0, n_samples, args.batch_size)
        y_real = Y_train[real_idx, :, :, :]

        # Fake samples: generator output on an independently drawn batch.
        fake_idx = np.random.randint(0, n_samples, args.batch_size)
        o_in = O_train[fake_idx, :, :, :]
        t_in = T_train[fake_idx, :, :, :]
        y_in = Y_train[fake_idx, :, :, :]
        # NOTE(review): [0] presumably selects the first output of a
        # multi-output generator (the generated image) -- confirm against the
        # generator's definition.
        y_fake = generator.predict([o_in, t_in, y_in])[0]

        # NOTE(review): d_disc is assumed to label rows in this same order
        # (real block first, fake block second) -- confirm where it is built.
        y_disc = np.vstack([y_real, y_fake])
        history = adversary.fit(y_disc, d_disc, verbose=0)
        # fit() runs a single epoch by default, so history['loss'] has one entry.
        total_loss += history.history['loss'][0]

    return total_loss / nsteps
评论列表
文章目录