bingrad_common.py source code

python

Project: terngrad  Author: wenwei202
def average_gradients2(tower_grads):
  """Like average_gradients(), but returns (shared gradient, unshared variable) pairs for every tower.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over towers. The inner list is over the (gradient, variable) pairs
      computed on that tower.
  Returns:
    List of lists of (gradient, variable) pairs, one inner list per tower,
    where each gradient has been averaged across all towers and each variable
    is that tower's own copy.
  """
  res = []
  # average_gradients() is defined elsewhere (not shown here); it is assumed to
  # return one (averaged gradient, variable) pair per variable.
  mean_grads = average_gradients(tower_grads)
  for grad_and_vars in tower_grads:
    _grads = []
    # Pair each averaged gradient with this tower's own variable.
    for mean_grad_and_var, tower_grad_and_var in zip(mean_grads, grad_and_vars):
      _grads.append((mean_grad_and_var[0], tower_grad_and_var[1]))
    res.append(_grads)

  return res
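
The pairing behavior can be checked without TensorFlow. The following is a minimal sketch, not the project's code: it assumes gradients are plain Python lists of floats and variables are name strings, and it uses a hypothetical stand-in for average_gradients() that averages element-wise and keeps the first tower's variable. The real average_gradients() in terngrad operates on tensors and may differ in detail.

```python
def average_gradients(tower_grads):
    """Hypothetical stand-in: average each variable's gradient over all towers."""
    averaged = []
    # zip(*tower_grads) groups the (gradient, variable) pairs of one variable
    # across towers.
    for grad_and_vars in zip(*tower_grads):
        grads = [g for g, _ in grad_and_vars]
        # Element-wise mean over towers.
        mean_grad = [sum(vals) / len(vals) for vals in zip(*grads)]
        # Assumption: keep the first tower's variable for the averaged pair.
        averaged.append((mean_grad, grad_and_vars[0][1]))
    return averaged


def average_gradients2(tower_grads):
    """Pair every averaged gradient with each tower's own variable."""
    mean_grads = average_gradients(tower_grads)
    return [
        [(mean_grad, var) for (mean_grad, _), (_, var) in zip(mean_grads, tower)]
        for tower in tower_grads
    ]


# Two towers with two variables each; variables are represented by name strings.
tower_grads = [
    [([1.0, 2.0], "tower0/w"), ([4.0], "tower0/b")],
    [([3.0, 6.0], "tower1/w"), ([8.0], "tower1/b")],
]

result = average_gradients2(tower_grads)
for tower in result:
    for grad, var in tower:
        print(var, grad)
```

Every tower receives the same averaged gradient values but keeps its own variable, which is what lets each tower apply the shared update to its local copy.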