network.py source code

python

Project: dist-dqn  Author: viswanathgs
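import tensorflow as tf  # TensorFlow 1.x graph-mode API

# A classmethod of the network class in network.py (hence `cls`).
def _init_loss(cls, config, q, expected_q, actions, reg_loss=None,
               summaries=None):
  """
  Set up the loss function and apply regularization if provided.

  @return: loss_op
  """
  # `actions` is expected to be a one-hot mask, so the row-wise sum of
  # q * actions selects the Q-value of the action that was taken.
  q_masked = tf.reduce_sum(tf.multiply(q, actions), axis=1)
  # Mean squared error between predicted and target Q-values.
  loss = tf.reduce_mean(tf.squared_difference(q_masked, expected_q))
  if reg_loss is not None:
    loss += config.reg_param * reg_loss

  if summaries is not None:
    summaries.append(tf.summary.scalar('loss', loss))

  return loss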
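To show numerically what the masked loss in _init_loss computes, here is a minimal pure-Python sketch with no TensorFlow dependency. It assumes `actions` is a one-hot mask over the action dimension. The helper names `one_hot` and `masked_q_loss` are hypothetical, and `reg_param` stands in for `config.reg_param`. This is an illustration of the arithmetic, not the project's implementation.

```python
from typing import List, Optional, Sequence


def one_hot(indices: Sequence[int], num_actions: int) -> List[List[float]]:
    """Hypothetical helper: build one-hot rows to use as the `actions` mask."""
    return [[1.0 if a == i else 0.0 for a in range(num_actions)] for i in indices]


def masked_q_loss(q: List[List[float]], expected_q: List[float],
                  actions: List[List[float]], reg_loss: Optional[float] = None,
                  reg_param: float = 0.0) -> float:
    """Hypothetical mirror of _init_loss on plain lists."""
    # Row-wise sum of q * actions: with a one-hot mask this is Q(s, a_taken).
    q_masked = [sum(qv * m for qv, m in zip(q_row, mask_row))
                for q_row, mask_row in zip(q, actions)]
    # Mean of the squared differences against the target Q-values.
    loss = sum((qm - t) ** 2 for qm, t in zip(q_masked, expected_q)) / len(q_masked)
    # Optional weighted regularization term, as in the TensorFlow version.
    if reg_loss is not None:
        loss += reg_param * reg_loss
    return loss


q = [[1.0, 2.0, 3.0], [0.5, 0.0, -1.0]]   # Q-values for 2 states, 3 actions
actions = one_hot([2, 0], 3)               # action 2 taken in state 0, action 0 in state 1
expected_q = [2.0, 1.5]                    # target Q-values
print(masked_q_loss(q, expected_q, actions))                               # 1.0
print(masked_q_loss(q, expected_q, actions, reg_loss=4.0, reg_param=0.25))  # 2.0
```

Multiplying by a one-hot mask and summing each row is a vectorized way to gather Q(s, a) for the taken action. It avoids per-row indexing, which is why the TensorFlow code uses a multiply followed by reduce_sum.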