multi_pass_optimizer.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:resnet 作者: renmengye 项目源码 文件源码
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Accumulates gradients across passes; applies them on the final pass.

  For the first `num_passes - 1` calls this only folds the incoming
  gradients into `self._grad_cache`. On the final call it additionally
  averages the cached gradients, applies them via the wrapped optimizer
  (`self.opt`), and zeroes out the cache.

  Args:
    grads_and_vars: List of (gradient, variable) pairs, as produced by
      an optimizer's `compute_gradients`.
    global_step: Optional variable incremented by the wrapped optimizer
      when weights are updated.
    name: Optional name for the apply operation.

  Returns:
    A TensorFlow op: the gradient-accumulation op on intermediate
    passes, or the combined accumulate/apply/zero-out op on the final
    pass.

  Raises:
    Exception: If called more than `num_passes` times without reset.
  """
  # Guard clause: refuse extra calls beyond the configured pass count.
  if self._count > self.num_passes - 1:
    raise Exception("You cannot call more apply_gradients")
  grad_add_ops = []
  for grad, var in grads_and_vars:
    if grad is None:
      # Record that this variable has no gradient so the averaging
      # pass below knows about it.  (Bug fix: the original referenced
      # an undefined name `v` here.)
      if var not in self._grad_cache:
        self._grad_cache[var] = None
      continue
    # Bug fix: was `self.grad_cache[var]`; every other access in this
    # method uses `self._grad_cache`.
    _grad_cache = self._grad_cache[var]
    if self._method == "cumsum":
      # Add grad / num_passes each pass; after the last pass the cache
      # holds the average directly.
      _div = tf.div(grad, self.num_passes)
      grad_add_ops.append(_grad_cache.assign_add(_div))
    else:
      # Store this pass's gradient in its own leading-axis slot; the
      # mean over that axis is taken on the final pass.
      _add = tf.expand_dims(grad, 0)
      grad_add_ops.append(
          tf.scatter_update(_grad_cache, [self._count], _add))
  grad_add_op = tf.group(*grad_add_ops)
  if self._count < self.num_passes - 1:
    # Intermediate pass: only accumulate.
    final_op = grad_add_op
  else:
    # Final pass: average, apply, then clear the cache.
    zero_out_ops = []
    with tf.control_dependencies([grad_add_op]):
      # NOTE(review): entries cached as None (no-gradient variables)
      # would break tf.identity/tf.reduce_mean here — preserved from
      # the original; confirm callers never hit that case.
      if self._method == "cumsum":
        # Cache already holds the average (divided before adding).
        grad_avg = [(tf.identity(gg), var)
                    for var, gg in self._grad_cache.items()]
      else:
        # Average over the per-pass leading axis.
        grad_avg = [(tf.reduce_mean(gg, [0]), var)
                    for var, gg in self._grad_cache.items()]

      # Update the weight variables through the wrapped optimizer.
      with tf.control_dependencies([grad_add_op]):
        weight_update = self.opt.apply_gradients(
            grad_avg, global_step=global_step, name=name)

      # Zero out the gradient cache once the weights are updated.
      with tf.control_dependencies([weight_update]):
        for grad, var in grad_avg:
          _grad_cache = self._grad_cache[var]
          if _grad_cache is not None:
            _zeros = tf.zeros(_grad_cache.get_shape(),
                              dtype=_grad_cache.dtype)
            zero_out_ops.append(_grad_cache.assign(_zeros))
    final_op = tf.group(*zero_out_ops)
  self._count += 1
  return final_op
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号