def clip_gradient(pair_list,
                  max_norm):
    """Perform gradient clipping by global norm.

    If the global norm of all gradients exceeds ``max_norm``, every gradient
    is rescaled by ``max_norm / global_norm`` so that the clipped global norm
    equals ``max_norm``; otherwise the gradients are returned unchanged.

    :param pair_list: Iterable of (grad, var) pairs. A one-shot iterator
        (e.g. ``zip(grads, vars)`` on Python 3) is accepted; it is
        materialized before use.
    :param max_norm: The maximum allowed global norm.
    :return: Tuple ``(clipped_pairs, raw_norm, clipped_norm)``:
        ``clipped_pairs`` is a list of (clipped_grad, var) tuples,
        ``raw_norm`` is the global norm of the original gradients and
        ``clipped_norm`` is the global norm of the clipped gradients.
    """
    # Materialize once: the pairs are traversed twice below, and a one-shot
    # iterator would otherwise be exhausted after the first pass, silently
    # producing an empty result.
    pair_list = list(pair_list)
    grads = [g for g, _ in pair_list]
    variables = [v for _, v in pair_list]
    # tf.clip_by_global_norm returns (clipped_list, norm_before_clipping);
    # per the TF docs, None entries are ignored when computing the norm.
    clipped_grads, raw_norm = tf.clip_by_global_norm(grads, max_norm)
    clipped_norm = tf.global_norm(clipped_grads)
    # Distinct local names: the original reused `grad` as a comprehension
    # variable, which on Python 2 leaks and clobbers the clipped norm.
    clipped_pairs = list(zip(clipped_grads, variables))
    return clipped_pairs, raw_norm, clipped_norm
评论列表
文章目录