yt8m_utils.py source code

python
Project: video_labelling_using_youtube8m  Author: LittleWat
import tensorflow as tf


def combine_gradients(tower_grads):
    """Calculate the combined gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
      tower_grads: List of lists of (gradient, variable) tuples. The outer list
        is over towers. The inner list is over the (gradient, variable) pairs
        computed by that tower.
    Returns:
       List of pairs of (gradient, variable) where the gradient has been summed
       across all towers.
    """
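    # Drop (gradient, variable) pairs whose gradient is None, tower by tower.
    filtered_grads = [[x for x in grad_list if x[0] is not None] for grad_list in tower_grads]
    final_grads = []
    # Assumes every tower yields the same variables in the same order.
    for i in range(len(filtered_grads[0])):
        # Collect the i-th (gradient, variable) pair from each tower.
        grads = [filtered_grads[t][i] for t in range(len(filtered_grads))]
        # Stack the per-tower gradients along a new axis 0, then sum over it.
        grad = tf.stack([x[0] for x in grads], 0)
        grad = tf.reduce_sum(grad, 0)
        # The variable is shared, so take it from the first tower.
        final_grads.append((grad, filtered_grads[0][i][1],))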

    return final_grads
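
The summation logic can be traced without TensorFlow. The sketch below is a dependency-free analogue, not the project's code: `combine_gradients_py` is a hypothetical name, gradients are assumed to be plain lists of floats standing in for tensors, and variables are assumed to be strings. It keeps the same steps: filter out `None` gradients, then sum the i-th gradient across towers.

```python
def combine_gradients_py(tower_grads):
    """Pure-Python analogue: sum each variable's gradient across towers."""
    # Drop pairs whose gradient is None, tower by tower.
    filtered = [[pair for pair in tower if pair[0] is not None]
                for tower in tower_grads]
    combined = []
    for i in range(len(filtered[0])):
        # The i-th (gradient, variable) pair from every tower.
        pairs = [filtered[t][i] for t in range(len(filtered))]
        grads = [gradient for gradient, _ in pairs]
        # Element-wise sum; stands in for tf.stack followed by tf.reduce_sum.
        summed = [sum(components) for components in zip(*grads)]
        combined.append((summed, pairs[0][1]))
    return combined


# Two hypothetical towers sharing variables "w" and "b"; "unused" has no gradient.
tower_0 = [([1.0, 2.0], "w"), (None, "unused"), ([0.5], "b")]
tower_1 = [([3.0, 4.0], "w"), (None, "unused"), ([1.5], "b")]
result = combine_gradients_py([tower_0, tower_1])
print(result)
```

Because pairs are matched by position after filtering, the result is only meaningful when every tower reports the same variables in the same order and has `None` gradients for the same variables.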