def combine_gradients(tower_grads):
    """Sum the gradient of each shared variable across all towers.

    Note that this function provides a synchronization point across all
    towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over towers; the inner list holds the (gradient, variable)
            pairs computed on that tower. Pairs whose gradient is None are
            dropped before combining.

    Returns:
        List of (gradient, variable) pairs where each gradient is the sum of
        the corresponding gradients from every tower and the variable is the
        one reported by the first tower. Returns an empty list when
        `tower_grads` is empty or no tower has a non-None gradient.

    Raises:
        IndexError: If a later tower has fewer non-None gradients than the
            first tower.
    """
    # Drop pairs without a gradient so positions are compared only among
    # entries that actually have something to sum.
    filtered_grads = [
        [pair for pair in grad_list if pair[0] is not None]
        for grad_list in tower_grads
    ]
    if not filtered_grads:
        # No towers at all: nothing to combine (previously an IndexError).
        return []

    final_grads = []
    # The first tower defines how many gradients there are and which variable
    # each combined gradient is paired with.
    for i, (_, variable) in enumerate(filtered_grads[0]):
        per_tower = [tower[i][0] for tower in filtered_grads]
        # Stack along a new leading axis, then reduce over it to get the sum.
        stacked = tf.stack(per_tower, 0)
        summed = tf.reduce_sum(stacked, 0)
        final_grads.append((summed, variable))
    return final_grads
yt8m_utils.py 文件源码
python
阅读 34
收藏 0
点赞 0
评论 0
评论列表
文章目录