import tensorflow as tf


def _average_gradients(tower_grads, include_square=False):
  """Calculate the average gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over individual towers. The inner list is over the (gradient,
      variable) pairs computed for each variable on that tower.
    include_square: If True, also return the element-wise mean of the squared
      per-tower gradients for each variable.

  Returns:
    List of (gradient, variable) pairs where the gradient has been averaged
    across all towers. If include_square is True, a second list of
    (squared gradient, variable) pairs is returned as well.
  """
  average_grads = []
  average_grads_square = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
    grads = []
    none_count = 0
    for g, _ in grad_and_vars:
      if g is None:
        none_count += 1
        continue
      # Add a 0th dimension to the gradient to represent the tower.
      expanded_g = tf.expand_dims(g, 0)
      # Append on a 'tower' dimension which we will average over below.
      grads.append(expanded_g)
    if none_count == 0:
      # Average over the 'tower' dimension. (tf.concat(dim, values) and
      # tf.mul below follow the pre-1.0 TensorFlow API this code targets.)
      grad_cat = tf.concat(0, grads)
      grad = tf.reduce_mean(grad_cat, 0)
      # The Variables are redundant because they are shared across towers,
      # so we just return the first tower's pointer to the Variable.
      v = grad_and_vars[0][1]
      average_grads.append((grad, v))
      if include_square:
        # Also average the element-wise squared per-tower gradients,
        # i.e. the second moment of the gradient across towers.
        grad2 = tf.mul(grad_cat, grad_cat, name="square_gradient")
        grad2 = tf.reduce_mean(grad2, 0)
        average_grads_square.append((grad2, v))
    elif none_count == len(grad_and_vars):
      # Every tower produced a None gradient for this variable; skip it.
      print("None gradient for %s" % grad_and_vars[0][1].op.name)
    else:
      raise ValueError("Gradient for %s is None on some towers but not "
                       "others" % grad_and_vars[0][1].op.name)
  if include_square:
    return average_grads, average_grads_square
  return average_grads
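
For context, here is a minimal sketch of how this function is typically driven in the standard multi-GPU tower setup; the optimizer choice, num_gpus, and the build_tower_loss helper below are illustrative placeholders, not part of the code above:

opt = tf.train.GradientDescentOptimizer(0.1)
tower_grads = []
for i in range(num_gpus):
  with tf.device("/gpu:%d" % i):
    # Each tower computes the loss on its shard of the batch and collects
    # (gradient, variable) pairs for the shared variables.
    loss = build_tower_loss(i)  # hypothetical model-building helper
    tower_grads.append(opt.compute_gradients(loss))

# Synchronization point: average the per-tower gradients, then apply once.
grads = _average_gradients(tower_grads)
train_op = opt.apply_gradients(grads)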