ram.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:tensorflow-ram 作者: qingzew 项目源码 文件源码
def grad(self, loc_mean_t, loc_t, h_t, prob, pred, labels):
        loss1, grads1 = self.grad_reinforcement(loc_mean_t, loc_t, h_t, prob, pred, labels)
        loss2, grads2 = self.grad_supervised(prob, labels)

        loss = (1 - self.lambda_) * loss1 + self.lambda_ * loss2

        grads = []
        for i in xrange(len(grads1)):
            grads.append((1 - self.lambda_) * grads1[i] + self.lambda_ * grads2[i])

        tvars = tf.trainable_variables()
        grads = zip(grads, tvars)

        tf.scalar_summary('loss', loss)
        tf.scalar_summary('loss_reinforcement', loss1)
        tf.scalar_summary('loss_supervised', loss2)

        return loss, grads
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号