import six
import tensorflow as tf


def _is_all_zeros(grad):
  """Returns a boolean Tensor that is True if `grad` contains only zeros."""
  return tf.equal(tf.count_nonzero(grad), 0)


def clip_gradients_by_global_norm(gradients_variables, clip_norm=20.):
  """Clips gradients of a multitask loss by their global norm.

  Ignores all-zero tensors when computing the global norm.

  Args:
    gradients_variables: a list of (gradient, variable) pairs.
    clip_norm: a float Tensor, the global norm to clip on. Default is 20.0.

  Returns:
    list: a list of (gradient, variable) pairs of the same type as
      gradients_variables.
    fixed_global_norm: a 0-D (scalar) Tensor representing the global norm.
  """
  gradients, variables = six.moves.zip(*gradients_variables)

  def _replace_nonexisting_grad(grad):
    """Keeps None gradients; replaces all-zero gradients with a 0-D zero."""
    if grad is None:
      return grad
    all_zeros = _is_all_zeros(grad)
    return tf.cond(
        all_zeros,
        lambda: tf.zeros([], dtype=tf.as_dtype(grad.dtype)),
        lambda: grad)

  # All-zero gradients are swapped for scalar zeros so they do not affect the
  # global norm; None gradients are ignored by tf.global_norm.
  nonzero_gradients = [_replace_nonexisting_grad(g) for g in gradients]
  fixed_global_norm = tf.global_norm(nonzero_gradients)

  # Clip the original gradients using the norm computed above.
  gradients, _ = tf.clip_by_global_norm(
      gradients, clip_norm, use_norm=fixed_global_norm)
  return list(six.moves.zip(gradients, variables)), fixed_global_norm
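
For reference, here is a minimal usage sketch in TF 1.x graph mode; the toy variables, loss, optimizer, and learning rate below are illustrative assumptions, not part of the original code. The variable b enters the loss with a zero coefficient, so it receives an all-zero gradient and exercises the "ignore all-zero tensors" path.

import tensorflow as tf

# Toy setup (illustrative only): w drives the loss; b gets an all-zero
# gradient because it enters the loss with a zero coefficient.
w = tf.Variable([3.0, 4.0], name="w")
b = tf.Variable([1.0], name="b")
loss = tf.reduce_sum(tf.square(w)) + 0.0 * tf.reduce_sum(b)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
gradients_variables = optimizer.compute_gradients(loss, var_list=[w, b])

# Clip using the all-zero-aware global norm, then apply the update.
clipped_gvs, global_norm = clip_gradients_by_global_norm(
    gradients_variables, clip_norm=5.)
train_op = optimizer.apply_gradients(clipped_gvs)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  _, norm_value = sess.run([train_op, global_norm])
  print("global norm used for clipping:", norm_value)  # 10.0 for this toy loss

Because tf.global_norm and tf.clip_by_global_norm both accept None entries, variables that do not reach the loss at all can also be passed through unchanged.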