predictron.py source file

python

Project: predictron  Author: brendanator
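import tensorflow as tf  # TensorFlow 1.x API (variable_scope, get_collection)


def loss(preturns, lambda_preturn, labels):
  with tf.variable_scope('loss'):
    # Error of every intermediate preturn against the target; labels gain
    # an axis at dimension 1 so they broadcast across the preturns.
    preturns_loss = tf.reduce_mean(
        tf.squared_difference(preturns, tf.expand_dims(labels, 1)))

    # Error of the combined lambda-preturn against the target.
    lambda_preturn_loss = tf.reduce_mean(
        tf.squared_difference(lambda_preturn, labels))

    # Pull each preturn toward the lambda-preturn; stop_gradient keeps the
    # lambda-preturn fixed so it acts as the target here.
    consistency_loss = tf.reduce_mean(
        tf.squared_difference(
            preturns, tf.stop_gradient(tf.expand_dims(lambda_preturn, 1))))

    # get_collection returns a list of tensors, so sum it to a scalar
    # before adding it to a loss (an empty collection contributes 0).
    l2_losses = tf.get_collection('losses')
    l2_loss = tf.add_n(l2_losses) if l2_losses else 0.0

    total_loss = preturns_loss + lambda_preturn_loss + consistency_loss
    consistency_loss += l2_loss
    return total_loss, consistency_loss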
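
To make the three terms easier to check by hand, here is a minimal plain-Python sketch of the same arithmetic. It is not the project's code: the helper names `mean_squared_difference` and `predictron_losses`, the scalar labels, and the sample batch are all assumptions made for illustration, and `stop_gradient` has no counterpart because nothing is differentiated here.

```python
def mean_squared_difference(xs, ys):
    """Mean of (x - y)^2 over two equal-length flat sequences."""
    assert len(xs) == len(ys)
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def predictron_losses(preturns, lambda_preturn, labels):
    """Hypothetical helper mirroring the three loss terms.

    preturns:       one row per example, each a list of per-depth estimates
    lambda_preturn: one scalar per example
    labels:         one scalar target per example (an assumption)
    """
    # Flatten the preturns and repeat each per-example value once per
    # depth; this plays the role of expand_dims(..., 1) plus broadcasting.
    flat_preturns = [p for row in preturns for p in row]
    repeated_labels = [y for row, y in zip(preturns, labels) for _ in row]
    repeated_lambda = [g for row, g in zip(preturns, lambda_preturn)
                       for _ in row]

    preturns_loss = mean_squared_difference(flat_preturns, repeated_labels)
    lambda_preturn_loss = mean_squared_difference(lambda_preturn, labels)
    consistency_loss = mean_squared_difference(flat_preturns, repeated_lambda)
    return preturns_loss, lambda_preturn_loss, consistency_loss


# Made-up batch: 2 examples, 3 depth estimates each.
preturns = [[1.0, 2.0, 3.0], [0.0, 0.0, 2.0]]
lambda_preturn = [2.0, 1.0]
labels = [2.0, 0.0]

p_loss, l_loss, c_loss = predictron_losses(preturns, lambda_preturn, labels)
print(p_loss, l_loss, c_loss)
```

The first two terms compare predictions with the labels, while the third compares the preturns only with the lambda-preturn, which is presumably why the original function returns it separately from `total_loss`: it needs no labels to compute.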