def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Accumulates gradients, applying the averaged update on the last pass."""
    grad_add_ops = []
    if self._count <= self.num_passes - 1:
        for grad, var in grads_and_vars:
            if grad is not None:
                _grad_cache = self._grad_cache[var]
                if self._method == "cumsum":
                    # Running sum of grad / num_passes, i.e. an incremental average.
                    _div = tf.div(grad, self.num_passes)
                    _add_op = _grad_cache.assign_add(_div)
                    grad_add_ops.append(_add_op)
                else:
                    # Store this pass's gradient in its own slot of the cache.
                    _add = tf.expand_dims(grad, 0)
                    _assign_op = tf.scatter_update(_grad_cache, [self._count], _add)
                    grad_add_ops.append(_assign_op)
            else:
                if var not in self._grad_cache:
                    self._grad_cache[var] = None
    else:
        raise Exception(
            "Cannot call apply_gradients more than num_passes times.")
    grad_add_op = tf.group(*grad_add_ops)
    if self._count < self.num_passes - 1:
        # Not the last pass yet: only accumulate gradients.
        final_op = grad_add_op
    else:
        zero_out_ops = []
        with tf.control_dependencies([grad_add_op]):
            if self._method == "cumsum":
                # The cache already holds the average; just read it out.
                grad_avg = [(tf.identity(gg), var)
                            for var, gg in self._grad_cache.items()]
            else:
                # Average over the per-pass slots stored along axis 0.
                grad_avg = [(tf.reduce_mean(gg, [0]), var)
                            for var, gg in self._grad_cache.items()]
        # Update the weight variables.
        with tf.control_dependencies([grad_add_op]):
            weight_update = self.opt.apply_gradients(
                grad_avg, global_step=global_step, name=name)
        # Zero out gradient cache.
        with tf.control_dependencies([weight_update]):
            for grad, var in grad_avg:
                _grad_cache = self._grad_cache[var]
                if _grad_cache is not None:
                    _grad_shape = _grad_cache.get_shape()
                    _zeros = tf.zeros(_grad_shape, dtype=_grad_cache.dtype)
                    _zero_out_grad = _grad_cache.assign(_zeros)
                    zero_out_ops.append(_zero_out_grad)
            zero_out_op = tf.group(*zero_out_ops)
        final_op = zero_out_op
    self._count += 1
    return final_op
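

# A minimal usage sketch, not from the original code: the wrapper class name
# `MultiPassOptimizer`, its constructor arguments, and `next_minibatch()` are
# hypothetical; only `apply_gradients` above is the source method. Because
# `self._count` is a Python counter, `apply_gradients` is called once per pass
# at graph-construction time, producing one op per pass; running the ops in
# order accumulates gradients, and the last op also applies the averaged
# update and zeroes the cache.
#
# base_opt = tf.train.GradientDescentOptimizer(0.1)
# opt = MultiPassOptimizer(base_opt, num_passes=4, method="cumsum")  # assumed API
# grads_and_vars = base_opt.compute_gradients(loss)
# pass_ops = [opt.apply_gradients(grads_and_vars) for _ in range(4)]
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for op in pass_ops:
#         # next_minibatch() stands in for feeding a fresh mini-batch each pass;
#         # the final run applies the averaged weight update.
#         sess.run(op, feed_dict=next_minibatch())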