resnet_cifar10_multi_gpu.py source code

python

Project: dlbench  Author: hclhkbu
import tensorflow as tf


def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
      tower_grads: List of lists of (gradient, variable) tuples. The outer list
        is over towers. The inner list is over the (gradient, variable) pairs
        computed on that tower, in the same variable order on every tower.
    Returns:
      List of (gradient, variable) pairs where the gradient has been averaged
      across all towers.
    """
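    average_grads = []
    # zip(*tower_grads) regroups the per-tower lists so that each iteration
    # sees one variable's (gradient, variable) pair from every tower.
    for single_grads in zip(*tower_grads):
        grads = [g for g, _ in single_grads]
        # Sum the per-tower gradients, then scale by 1/N to get the mean.
        grad = tf.add_n(grads)
        grad = tf.multiply(grad, 1.0/len(grads))
        # The variable is shared across towers, so the first tower's
        # reference to it is sufficient.
        v = single_grads[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads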
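
The regrouping step, `zip(*tower_grads)`, is easiest to follow on a small input. The sketch below is not part of the original file: it mirrors the same loop with plain Python floats standing in for gradient tensors and strings standing in for variables, so it runs without TensorFlow. The name `average_gradients_plain` and the two-tower input are hypothetical and exist only for illustration.

```python
# Illustration only: floats stand in for gradient tensors and strings
# stand in for shared variables. This is not the TensorFlow version.
def average_gradients_plain(tower_grads):
    averaged = []
    # zip(*tower_grads) turns "one list per tower" into "one tuple per
    # variable", each holding that variable's pair from every tower.
    for per_variable in zip(*tower_grads):
        grads = [g for g, _ in per_variable]
        mean_grad = sum(grads) / len(grads)
        # Variables are shared, so the first tower's reference is enough.
        averaged.append((mean_grad, per_variable[0][1]))
    return averaged


# Two hypothetical towers, each reporting gradients for "w" and "b".
tower_grads = [
    [(1.0, "w"), (4.0, "b")],  # tower 0
    [(3.0, "w"), (8.0, "b")],  # tower 1
]
result = average_gradients_plain(tower_grads)
print(result)  # [(2.0, 'w'), (6.0, 'b')]
```

Like the function above, this sketch assumes every tower lists its variables in the same order and that no gradient is missing.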