train.py source code

python

Project: BDD_Driving_Model  Author: gy20073
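import tensorflow as tf


def _average_gradients(tower_grads, include_square=False):
  """Calculate the average gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over towers; each inner list holds the (gradient, variable) pairs
      computed on that tower, in the same variable order on every tower.
    include_square: If True, also compute the tower-averaged element-wise
      square of each gradient.
  Returns:
     List of pairs of (gradient, variable) where the gradient has been averaged
     across all towers. If include_square is True, returns the tuple
     (average_grads, average_grads_square) instead, where the second list pairs
     each variable with the mean over towers of its squared gradient.
     Variables whose gradient is None on every tower are skipped.
  """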
  average_grads = []
  average_grads_square = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads = []

    none_count = 0  # number of towers that produced no gradient for this variable
    for g, v in grad_and_vars:
      if g is None:
          none_count += 1
          continue
      # Add 0 dimension to the gradients to represent the tower.
      expanded_g = tf.expand_dims(g, 0)

      # Append on a 'tower' dimension which we will average over below.
      grads.append(expanded_g)

    if none_count == 0:
        # Concatenate along the new 'tower' axis, then average over it.
        # tf.concat takes (values, axis) in TensorFlow >= 1.0.
        grad_cat = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad_cat, 0)

        # Keep in mind that the Variables are redundant because they are shared
        # across towers, so we just return the first tower's pointer to the
        # Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)

        if include_square:
            # Mean over towers of the element-wise squared gradients,
            # i.e. E[g^2], not (E[g])^2.
            grad2 = tf.multiply(grad_cat, grad_cat, name="square_gradient")
            grad2 = tf.reduce_mean(grad2, 0)
            average_grads_square.append((grad2, v))

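    elif none_count == len(grad_and_vars):
        # No tower produced a gradient for this variable: skip it.
        print("None gradient for %s" % (grad_and_vars[0][1].op.name))
    else:
        # Some towers produced a gradient and others did not, so the towers
        # are inconsistent.
        raise ValueError("None gradient error: gradient is None on only some "
                         "towers for %s" % grad_and_vars[0][1].op.name)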
  if include_square:
      return average_grads, average_grads_square
  else:
      return average_grads
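The sketch below makes the tower/variable layout and the None handling concrete. It reproduces the same logic on plain Python lists instead of TensorFlow tensors, under two assumptions: each gradient is a flat list of floats (or None), and each variable is represented only by its name. `average_gradients_py` and the sample `tower_grads` data are hypothetical illustrations and are not part of the BDD_Driving_Model project.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical stand-in types: a "gradient" is a flat list of floats (or None)
# and a "variable" is represented only by its name.
Grad = Optional[List[float]]
GradVar = Tuple[Grad, str]


def average_gradients_py(
    tower_grads: Sequence[Sequence[GradVar]], include_square: bool = False
):
    """Framework-free sketch of _average_gradients on plain Python lists."""
    average_grads = []
    average_grads_square = []
    # zip(*...) transposes [tower][variable] into [variable][tower].
    for grad_and_vars in zip(*tower_grads):
        name = grad_and_vars[0][1]
        grads = [g for g, _ in grad_and_vars if g is not None]
        if len(grads) == len(grad_and_vars):
            n = len(grads)
            # Element-wise mean over the tower axis.
            mean = [sum(vals) / n for vals in zip(*grads)]
            average_grads.append((mean, name))
            if include_square:
                # Mean of squares over towers, like reduce_mean(g * g, 0).
                mean_sq = [sum(x * x for x in vals) / n for vals in zip(*grads)]
                average_grads_square.append((mean_sq, name))
        elif not grads:
            # No tower produced a gradient: skip the variable.
            print("None gradient for %s" % name)
        else:
            # Gradient present on some towers but missing on others.
            raise ValueError("None gradient error: %s" % name)
    if include_square:
        return average_grads, average_grads_square
    return average_grads


# Two towers, three variables; "frozen" has no gradient on either tower.
tower_grads = [
    [([1.0, 2.0], "w"), (None, "frozen"), ([0.5], "b")],  # tower 0
    [([3.0, 4.0], "w"), (None, "frozen"), ([1.5], "b")],  # tower 1
]
avg, avg_sq = average_gradients_py(tower_grads, include_square=True)
print(avg)     # [([2.0, 3.0], 'w'), ([1.0], 'b')]
print(avg_sq)  # [([5.0, 10.0], 'w'), ([1.25], 'b')]
```

The squared output is the mean of the squares, not the square of the mean. Subtracting (E[g])² from it gives the per-element variance of the gradient across towers, which is [1.0, 1.0] for `w` in this example.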