ops.py 文件源码

python
阅读 49 收藏 0 点赞 0 评论 0

项目:photinia 作者: XoriieInpottn 项目源码 文件源码
def clip_gradient(pair_list,
                  max_norm):
    """Perform gradient clipping.
    If the gradients' global norm exceed 'max_norm', then shrink it to 'max_norm'.

    :param pair_list: (grad, var) pair list.
    :param max_norm: The max global norm.
    :return: (grad, var) pair list, the original gradients' norm, the clipped gradients' norm
    """
    grad_list = [grad for grad, _ in pair_list]
    grad_list, raw_grad = tf.clip_by_global_norm(grad_list, max_norm)
    grad = tf.global_norm(grad_list)
    pair_list = [(grad, pair[1]) for grad, pair in zip(grad_list, pair_list)]
    return pair_list, raw_grad, grad
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号