def apply_gradients(self, grads_and_vars, global_step=None):
    """Reset `self.mini_flag` to 0, then apply gradients via the wrapped optimizer.

    Builds an op that calls `self._optimizer.apply_gradients` under a control
    dependency on the assignment of ``[0.]`` to `self.mini_flag`. As a result,
    the flag reset runs before the gradient update whenever the returned op is
    executed. This method does NOT zero the accumulated gradient variables in
    `self.grads_and_vars`; no zeroing op is created here.

    Args:
        grads_and_vars (list): (gradient, variable) pairs, forwarded unchanged
            to the wrapped optimizer's `apply_gradients`.
        global_step (tf.Variable, optional): TensorFlow global-step variable,
            forwarded to the wrapped optimizer. Defaults to None.

    Returns:
        tf.Operation: The gradient-update op produced by the wrapped optimizer,
        gated on the `self.mini_flag` reset.

    Side effects:
        Rebinds `self.mini_flag` to the output tensor of the `tf.assign` op,
        replacing the handle previously stored on `self`.
    """
    # NOTE(review): this rebinds self.mini_flag from the flag variable to the
    # assign op's output tensor. Any code that reads self.mini_flag after this
    # call sees that tensor, not the original variable. It is preserved to keep
    # existing behavior; confirm whether a local name was intended instead.
    self.mini_flag = tf.assign(self.mini_flag, tf.constant([0], dtype=tf.float32))
    # Ops created inside this context run only after the flag reset above.
    with tf.control_dependencies([self.mini_flag]):
        optimize = self._optimizer.apply_gradients(
            grads_and_vars, global_step=global_step
        )
    return optimize
评论列表
文章目录