ram.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:tensorflow-ram 作者: qingzew 项目源码 文件源码
def grad_supervised(self, prob, labels):
        """
        return:
            loss = 1 / M * sum_i_{1..M} cross_entroy_loss(groundtruth, a_T)
            grads = grad(loss, params)
        inputs:
            prob
            labels = (n_batch,)
            [tensor variable]
        """
        labels = tf.cast(labels, tf.int64)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(prob, labels, name = 'cross_entropy_per_example')
        loss = tf.reduce_mean(cross_entropy, name = 'cross_entropy')
        tvars = tf.trainable_variables()
        grads = tf.gradients(loss, tvars)
        for i in xrange(len(grads)):
            if grads[i] == None:
                grads[i] = tf.zeros(shape = tvars[i].get_shape())

        return loss, grads
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号