def aggregate_gradients_using_copy_with_device_selection(
        benchmark_cnn, tower_grads, use_mean, check_inf_nan):
    """Aggregate per-tower gradients, choosing the device for each one.

    The i-th gradient (grouped across all towers) is aggregated under
    ``tf.device(devices[i % len(devices)])``. The device pool is
    ``benchmark_cnn.raw_devices`` when
    ``benchmark_cnn.local_parameter_device_flag == 'gpu'``; otherwise it is
    the single entry ``benchmark_cnn.param_server_device``.

    Args:
        benchmark_cnn: benchmark_cnn object; only the three attributes named
            above are read here.
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over towers, the inner list over individual gradients.
        use_mean: If True the mean is taken, else the sum of gradients is
            taken. Forwarded to aggregate_single_gradient_using_copy.
        check_inf_nan: If true, check grads for nans and infs. Forwarded to
            aggregate_single_gradient_using_copy.

    Returns:
        A tuple ``(aggregated, has_nan_or_inf)``. ``aggregated`` is the list
        of first results returned by aggregate_single_gradient_using_copy,
        one per gradient (per the helper's contract, an aggregated
        (gradient, variable) pair with the variable taken from the first
        tower). ``has_nan_or_inf`` is ``tf.reduce_any`` over the helper's
        per-gradient flags when ``check_inf_nan`` is truthy, otherwise None.
    """
    use_gpus = benchmark_cnn.local_parameter_device_flag == 'gpu'
    devices = (benchmark_cnn.raw_devices if use_gpus
               else [benchmark_cnn.param_server_device])

    aggregated = []
    nan_inf_flags = []
    # zip(*tower_grads) regroups the data so each item holds the entries for
    # one gradient position from every tower.
    for index, per_tower in enumerate(zip(*tower_grads)):
        target = devices[index % len(devices)]
        with tf.device(target):
            pair, flag = aggregate_single_gradient_using_copy(
                per_tower, use_mean, check_inf_nan)
            aggregated.append(pair)
            nan_inf_flags.append(flag)

    # The flags are only reduced when the caller asked for the check.
    any_nan_or_inf = tf.reduce_any(nan_inf_flags) if check_inf_nan else None
    return aggregated, any_nan_or_inf
# Web-page residue from the source article ("Comment list" / "Table of contents"); not part of the code.