def get_discriminator_op(self, r_preds, g_preds, d_weights,
                         learning_rate=1e-3, eps=1e-12, reg_weight=1e-6):
    """Returns an op that updates the discriminator weights.

    The discriminator loss is the binary cross-entropy of its predictions,
    treating real samples as label 1 and generated samples as label 0:

        dis_loss = -mean(log(r_preds + eps)) - mean(log(1 + eps - g_preds))

    plus an L2 penalty on `d_weights` scaled by `reg_weight`. Scalar
    summaries named 'regularization' and 'total' are also recorded.

    Args:
        r_preds: Tensor with shape (batch_size, num_timesteps, 1), the
            discriminator predictions for real data. Presumably
            probabilities in [0, 1], since the loss takes log(p).
        g_preds: Tensor with shape (batch_size, num_timesteps, 1), the
            discriminator predictions for generated data. Presumably
            probabilities in [0, 1], since the loss takes log(1 - p).
        d_weights: a list of trainable tensors representing the weights
            associated with the discriminator model. They are passed to
            `self.get_updates` and are also the tensors L2-regularized.
        learning_rate: learning rate of the Adam optimizer used for the
            discriminator update (default 1e-3).
        eps: small constant added inside the logarithms so a prediction of
            exactly 0 (real) or 1 (generated) does not yield log(0)
            (default 1e-12).
        reg_weight: multiplier applied to the summed `tf.nn.l2_loss` of
            `d_weights` (default 1e-6).

    Returns:
        dis_op, the op to run to train the discriminator, as returned by
        `self.get_updates`.
    """
    with tf.variable_scope('loss/discriminator'):
        discriminator_opt = tf.train.AdamOptimizer(learning_rate)

        # Push real predictions towards 1 and generated ones towards 0.
        r_loss = -tf.reduce_mean(tf.log(r_preds + eps))
        f_loss = -tf.reduce_mean(tf.log(1 + eps - g_preds))
        dis_loss = r_loss + f_loss

        with tf.variable_scope('regularization'):
            # With an empty d_weights the sum is 0, so the penalty is 0.0.
            dis_reg_loss = (
                sum(tf.nn.l2_loss(w) for w in d_weights) * reg_weight)
        tf.summary.scalar('regularization', dis_reg_loss)

        total_loss = dis_loss + dis_reg_loss
        with tf.variable_scope('discriminator_update'):
            # get_updates is defined elsewhere on this class; presumably it
            # applies the optimizer to d_weights — confirm there.
            dis_op = self.get_updates(total_loss, discriminator_opt,
                                      d_weights)
        tf.summary.scalar('total', total_loss)

    return dis_op
评论列表
文章目录