def average_gradients(cls, tower_grads):
    """Average per-tower gradients into one list of (gradient, variable) pairs.

    Args:
        tower_grads: A list with one entry per tower (device). Each entry is a
            list of ``(gradient, variable)`` pairs as produced by
            ``compute_gradients``. The entries are matched up positionally, so
            every tower must list the same variables in the same order.

    Returns:
        A list of ``(gradient, variable)`` pairs, one per variable. The
        gradient is the element-wise mean of that variable's gradients across
        the towers that produced one. If no tower produced a gradient (all
        entries are ``None``), the pair is ``(None, variable)`` so that callers
        such as ``apply_gradients`` can skip it. Returns ``[]`` when
        ``tower_grads`` is empty.
    """
    average_grads = []
    # zip(*...) regroups the per-tower lists into per-variable tuples:
    # ((g_tower0, v), (g_tower1, v), ...).
    for grads_and_vars in zip(*tower_grads):
        # Variables are shared across towers, so the first tower's variable
        # stands in for all of them.
        var = grads_and_vars[0][1]
        # A gradient may be None (e.g. the loss does not depend on the
        # variable); tf.stack cannot take None, so drop those entries.
        grads = [g for g, _ in grads_and_vars if g is not None]
        if not grads:
            average_grads.append((None, var))
            continue
        # Stacking adds a leading "tower" axis; averaging over it yields the
        # mean gradient (same result as expand_dims + concat + reduce_mean).
        grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)
        average_grads.append((grad, var))
    return average_grads
评论列表
文章目录