def average_gradients2(tower_grads):
    """Average gradients across towers, keeping each tower's own variables.

    This is like ``average_gradients()``, but it does not return a single list
    of (averaged gradient, variable) pairs. Instead it returns one list per
    tower, in which each averaged gradient is paired with that tower's own
    (unshared) variable. The averaged gradient objects themselves are shared:
    the same object appears at a given position in every tower's list.

    Note that this function provides a synchronization point across all
    towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over towers. Each inner list holds that tower's
            (gradient, variable) pairs, and is expected to list the variables
            in the same order for every tower.

    Returns:
        List with one entry per tower, in input order. Each entry is a list
        of (averaged gradient, variable) tuples, where the variable is taken
        from that tower's input list at the same position.
    """
    # NOTE(review): assumes average_gradients() returns one
    # (averaged_grad, var) pair per position, in the same order as each
    # tower's list. The pairing below is purely positional.
    mean_grads = average_gradients(tower_grads)
    # zip() stops at the shorter sequence, so a length mismatch between
    # mean_grads and a tower's list silently drops trailing pairs.
    return [
        [(mean_gv[0], tower_gv[1])
         for mean_gv, tower_gv in zip(mean_grads, grad_and_vars)]
        for grad_and_vars in tower_grads
    ]
评论列表
文章目录