nn.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tf_practice 作者: juho-lee 项目源码 文件源码
def get_train_op(loss,
        var_list=None,
        grad_clip=None,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999):

    optimizer = tf.train.AdamOptimizer(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2)
    if grad_clip is None:
        return optimizer.minimize(loss, var_list=var_list)
    else:
        gvs = optimizer.compute_gradients(loss, var_list=var_list)
        def clip(grad):
            if grad is None:
                return grad
            else:
                return tf.clip_by_value(grad, -grad_clip, grad_clip)
        capped_gvs = [(clip(grad), var) for grad, var in gvs]
        return optimizer.apply_gradients(capped_gvs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号