graph_definition.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:skiprnn-2017-telecombcn 作者: imatge-upc 项目源码 文件源码
def compute_gradients(loss, learning_rate, gradient_clipping=-1):
    """
    Create optimizer, compute gradients and (optionally) apply gradient clipping
    """
    opt = tf.train.AdamOptimizer(learning_rate)
    if gradient_clipping > 0:
        vars_to_optimize = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, vars_to_optimize), clip_norm=gradient_clipping)
        grads_and_vars = list(zip(grads, vars_to_optimize))
    else:
        grads_and_vars = opt.compute_gradients(loss)
    return opt, grads_and_vars
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号