model.py source code

python

Project: liveqa2017 Author: codekansas
import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.log, tf.variable_scope)

def get_discriminator_op(self, r_preds, g_preds, d_weights):
        """Returns an op that updates the discriminator weights correctly.

        Args:
            r_preds: Tensor with shape (batch_size, num_timesteps, 1), the
                discriminator predictions for real data.
            g_preds: Tensor with shape (batch_size, num_timesteps, 1), the
                discriminator predictions for generated data.
            d_weights: a list of trainable tensors representing the weights
                associated with the discriminator model.

        Returns:
            dis_op, the op to run to train the discriminator.
        """

        with tf.variable_scope('loss/discriminator'):
            discriminator_opt = tf.train.AdamOptimizer(1e-3)

            eps = 1e-12
            r_loss = -tf.reduce_mean(tf.log(r_preds + eps))
            f_loss = -tf.reduce_mean(tf.log(1 + eps - g_preds))
            dis_loss = r_loss + f_loss
            # dis_loss = tf.reduce_mean(g_preds) - tf.reduce_mean(r_preds)

            # tf.summary.scalar('real', r_loss)
            # tf.summary.scalar('generated', f_loss)

            with tf.variable_scope('regularization'):
                dis_reg_loss = sum([tf.nn.l2_loss(w) for w in d_weights]) * 1e-6
            tf.summary.scalar('regularization', dis_reg_loss)

            total_loss = dis_loss + dis_reg_loss
            with tf.variable_scope('discriminator_update'):
                dis_op = self.get_updates(total_loss, discriminator_opt,
                                          d_weights)
            tf.summary.scalar('total', total_loss)

        return dis_op
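To check the arithmetic of this loss without a TensorFlow 1.x environment, here is a minimal NumPy sketch of the same terms. It assumes the predictions are sigmoid outputs in [0, 1] and reproduces `tf.nn.l2_loss` as `sum(w ** 2) / 2`. The names `l2_loss`, `discriminator_loss`, and `wasserstein_style_loss` are hypothetical and are not part of the liveqa2017 project. The sketch covers only the loss values, not the optimizer step delegated to `self.get_updates`, whose implementation is not shown here.

```python
import numpy as np


def l2_loss(w):
    """NumPy equivalent of tf.nn.l2_loss: sum(w ** 2) / 2."""
    return np.sum(np.square(w)) / 2.0


def discriminator_loss(r_preds, g_preds, d_weights, eps=1e-12, reg_scale=1e-6):
    """Hypothetical re-implementation of the loss terms in get_discriminator_op.

    r_preds, g_preds: arrays of discriminator outputs in [0, 1],
        e.g. shape (batch_size, num_timesteps, 1).
    d_weights: list of discriminator weight arrays.
    Returns (dis_loss, dis_reg_loss, total_loss).
    """
    # Real data should be scored near 1: penalize -log(D(x)).
    r_loss = -np.mean(np.log(r_preds + eps))
    # Generated data should be scored near 0: penalize -log(1 - D(G(z))).
    f_loss = -np.mean(np.log(1.0 + eps - g_preds))
    dis_loss = r_loss + f_loss
    # Small L2 weight decay summed over every discriminator weight array.
    dis_reg_loss = sum(l2_loss(w) for w in d_weights) * reg_scale
    return dis_loss, dis_reg_loss, dis_loss + dis_reg_loss


def wasserstein_style_loss(r_preds, g_preds):
    """The alternative left commented out in the original:
    mean(D(G(z))) - mean(D(x))."""
    return np.mean(g_preds) - np.mean(r_preds)


# Usage: an undecided discriminator (every output 0.5) gives about 2 * ln 2.
r = np.full((2, 3, 1), 0.5)
g = np.full((2, 3, 1), 0.5)
dis, reg, total = discriminator_loss(r, g, [np.ones((2, 2))])
```

An undecided discriminator pays roughly `2 * ln 2` in cross-entropy, while a perfect one (real scored 1, generated scored 0) drives `dis_loss` toward 0; the `eps` term only keeps `log` away from zero. The `1e-6` weight-decay factor is small enough that the regularization term barely moves the total here (`2e-6` for one 2x2 array of ones).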