def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
    """Aggregate the per-tower gradients of one shared variable.

    The gradients are summed with ``tf.add_n``. When ``use_mean`` is set and
    more than one gradient was supplied, the sum is scaled by
    ``1.0 / number_of_gradients`` so the result is the mean instead.

    Note that this function provides a synchronization point across all
    towers.

    Args:
      grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
        pair holds the gradient of the variable as calculated on a single
        tower, so the number of pairs equals the number of towers.
      use_mean: If True, the mean of the gradients is taken; otherwise
        their sum is taken.
      check_inf_nan: If True, also build a check for NaN/Inf values in the
        gradients.

    Returns:
      A pair ``((gradient, variable), has_nan_or_inf)``. ``gradient`` is the
      aggregated gradient and ``variable`` is taken from the first tower's
      pair. ``has_nan_or_inf`` is the negation of "every element of the
      per-tower gradients is finite" when ``check_inf_nan`` is true, and
      ``None`` otherwise.
    """
    tower_grads = [tower_grad for tower_grad, _ in grad_and_vars]
    combined = tf.add_n(tower_grads)

    num_towers = len(tower_grads)
    if use_mean and num_towers > 1:
        combined = tf.multiply(combined, 1.0 / num_towers)

    # Every tower refers to the same shared variable; use the first one.
    shared_var = grad_and_vars[0][1]

    has_nan_or_inf = None
    if check_inf_nan:
        # The check inspects the individual tower gradients, not the
        # aggregated result computed above.
        all_finite = tf.reduce_all(tf.is_finite(tower_grads))
        has_nan_or_inf = tf.logical_not(all_finite)
    return (combined, shared_var), has_nan_or_inf
评论列表
文章目录