def _numerically_stable_global_norm(tensor_list):
"""Compute the global norm of a list of Tensors, with improved stability.
The global norm computation sometimes overflows due to the intermediate L2
step. To avoid this, we divide by a cheap-to-compute max over the
matrix elements.
Args:
tensor_list: A list of tensors, or `None`.
Returns:
A scalar tensor with the global norm.
"""
if np.all([x is None for x in tensor_list]):
return 0.0
list_max = tf.reduce_max([tf.reduce_max(tf.abs(x)) for x in
tensor_list if x is not None])
return list_max * tf.global_norm([x / list_max for x in tensor_list
if x is not None])
评论列表
文章目录