import tensorflow as tf


def aggregate_gradients_using_copy(tower_grads, use_mean, check_inf_nan):
  """Calculate the average gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over towers. The inner list is over individual gradients.
    use_mean: if True, the mean of the gradients is taken; otherwise their
      sum is taken.
    check_inf_nan: if True, check the gradients for NaNs and Infs.

  Returns:
    The tuple ([(average_gradient, variable),], has_nan_or_inf) where the
    gradient has been averaged across all towers. The variable is chosen from
    the first tower. has_nan_or_inf indicates whether any gradient contains
    a NaN or an Inf.
  """
  agg_grads = []
  has_nan_or_inf_list = []
  # zip(*tower_grads) groups the (gradient, variable) pairs that belong to
  # the same shared variable across all towers.
  for single_grads in zip(*tower_grads):
    grad_and_var, has_nan_or_inf = aggregate_single_gradient_using_copy(
        single_grads, use_mean, check_inf_nan)
    agg_grads.append(grad_and_var)
    has_nan_or_inf_list.append(has_nan_or_inf)

  if check_inf_nan:
    # A NaN/Inf in any single aggregated gradient flags the whole result.
    return agg_grads, tf.reduce_any(has_nan_or_inf_list)
  else:
    return agg_grads, None
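
The function delegates the per-variable work to aggregate_single_gradient_using_copy, which is not shown in this snippet. Below is a minimal sketch of what such a helper might look like, assuming it sums (or averages) one variable's gradients across towers, takes the variable from the first tower, and optionally checks the result for NaNs/Infs. The signature mirrors the call site above, but the body is an illustration, not the original implementation.

def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Illustrative sketch: aggregate one variable's gradients across towers.

  Args:
    grad_and_vars: tuple of (gradient, variable) pairs, one per tower, all
      referring to the same shared variable.
    use_mean: if True, divide the summed gradient by the number of towers.
    check_inf_nan: if True, also return a scalar bool tensor that is True
      when the aggregated gradient contains a NaN or an Inf.

  Returns:
    ((aggregated_gradient, variable), has_nan_or_inf_or_None)
  """
  grads = [g for g, _ in grad_and_vars]
  grad = tf.add_n(grads)  # element-wise sum across towers
  if use_mean and len(grads) > 1:
    grad = tf.multiply(grad, 1.0 / len(grads))
  # The variable is taken from the first tower; all towers share it.
  v = grad_and_vars[0][1]
  if check_inf_nan:
    has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.math.is_finite(grad)))
    return (grad, v), has_nan_or_inf
  return (grad, v), None

In this sketch the finiteness check runs on the already-summed gradient rather than on each per-tower gradient, which keeps the check to a single scalar per variable; a NaN or Inf in any tower propagates through tf.add_n, so nothing is missed.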