variable_mgr.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:benchmarks 作者: tensorflow 项目源码 文件源码
def get_gradients_to_apply(self, device_num, gradient_state):
    device_grads = gradient_state  # From 2nd result of preprocess_device_grads.

    avg_grads, self.grad_has_inf_nan = (
        variable_mgr_util.aggregate_gradients_using_copy_with_device_selection(
            self.benchmark_cnn,
            device_grads,
            use_mean=True,
            check_inf_nan=self.benchmark_cnn.enable_auto_loss_scale))

    # Make shadow variable on a parameter server for each original trainable
    # variable.
    for i, (g, v) in enumerate(avg_grads):
      my_name = variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/' + v.name
      if my_name.endswith(':0'):
        my_name = my_name[:-2]
      new_v = tf.get_variable(
          my_name,
          dtype=v.dtype.base_dtype,
          initializer=v.initial_value,
          trainable=True)
      avg_grads[i] = (g, new_v)
    return avg_grads
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号