def average_gradients(tower_grads):
    """Average each shared variable's gradient across all towers.

    Every averaged gradient depends on the gradients from all towers, so
    this function acts as a synchronization point across towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list has one entry per tower; each inner list holds that tower's
            (gradient, variable) pairs. Pairs are matched across towers by
            position (via zip), so every tower is assumed to list the shared
            variables in the same order -- NOTE(review): confirm at callers.

    Returns:
        List of (gradient, variable) pairs, one per variable. The gradient
        is the mean of the non-None gradients the towers produced for that
        variable; if every tower's gradient was None, the pair is returned
        as (None, variable).
    """
    average_grads = []
    # zip(*tower_grads) transposes tower-major data into variable-major:
    # each item is the tuple of (grad, var) pairs for one variable, one per
    # tower.
    for grad_and_vars in zip(*tower_grads):
        # The variable is shared, so the first tower's handle represents all.
        var = grad_and_vars[0][1]
        # Variables that do not contribute to the loss get a None gradient;
        # tf.add_n cannot sum None, so drop those entries before averaging.
        grads = [g for g, _ in grad_and_vars if g is not None]
        if not grads:
            average_grads.append((None, var))
            continue
        grad = tf.add_n(grads)
        grad = tf.multiply(grad, 1.0 / len(grads))
        average_grads.append((grad, var))
    return average_grads
评论列表
文章目录