def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Accumulates gradients for one pass; applies them on the final pass."""
    grad_add_ops = []
    if self._count <= self.num_passes - 1:
        for grad, var in grads_and_vars:
            if grad is not None:
                _grad_cache = self._grad_cache[var]
                if self._method == "cumsum":
                    # Running average: add grad / num_passes into the cache.
                    _div = tf.div(grad, self.num_passes)
                    _add_op = _grad_cache.assign_add(_div)
                    grad_add_ops.append(_add_op)
                else:
                    # Store this pass's gradient in its own slot of the cache.
                    _add = tf.expand_dims(grad, 0)
                    _assign_op = tf.scatter_update(
                        _grad_cache, [self._count], _add)
                    grad_add_ops.append(_assign_op)
            else:
                if var not in self._grad_cache:
                    self._grad_cache[var] = None
    else:
        raise Exception(
            "Cannot call apply_gradients more than num_passes times.")
    grad_add_op = tf.group(*grad_add_ops)
    if self._count < self.num_passes - 1:
        # Intermediate pass: only accumulate.
        final_op = grad_add_op
    else:
        # Final pass: average the cached gradients, apply them, then reset.
        zero_out_ops = []
        with tf.control_dependencies([grad_add_op]):
            if self._method == "cumsum":
                # The cache already holds the running average.
                grad_avg = [(tf.identity(gg), var)
                            for var, gg in self._grad_cache.items()]
            else:
                # Average over the per-pass slots.
                grad_avg = [(tf.reduce_mean(gg, [0]), var)
                            for var, gg in self._grad_cache.items()]
        # Update the weight variables.
        with tf.control_dependencies([grad_add_op]):
            weight_update = self.opt.apply_gradients(
                grad_avg, global_step=global_step, name=name)
        # Zero out the gradient cache.
        with tf.control_dependencies([weight_update]):
            for grad, var in grad_avg:
                _grad_cache = self._grad_cache[var]
                if _grad_cache is not None:
                    _grad_shape = _grad_cache.get_shape()
                    _zeros = tf.zeros(_grad_shape, dtype=_grad_cache.dtype)
                    _zero_out_grad = _grad_cache.assign(_zeros)
                    zero_out_ops.append(_zero_out_grad)
        zero_out_op = tf.group(*zero_out_ops)
        final_op = zero_out_op
    self._count += 1
    return final_op
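
For context, here is how the method above is meant to be driven: each call to apply_gradients builds the ops for one accumulation pass, and the final call also applies the averaged gradient and zeroes the cache. This is a minimal sketch; the class name MultiPassOptimizer and its constructor are assumptions for illustration, since only apply_gradients appears in the snippet.

import tensorflow as tf

# Toy model: one weight vector, mean-squared loss.
x = tf.placeholder(tf.float32, [None, 3])
w = tf.get_variable("w", [3], initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.square(tf.reduce_sum(x * w, axis=1) - 1.0))

num_passes = 2
base_opt = tf.train.GradientDescentOptimizer(0.1)
# Hypothetical wrapper; assumed to set self.opt, self.num_passes,
# self._method, self._count = 0, and one grad-cache variable per weight.
opt = MultiPassOptimizer(base_opt, num_passes=num_passes, method="cumsum")

gvs = base_opt.compute_gradients(loss)
# One op per pass; only the last op updates the weights and resets the cache.
pass_ops = [opt.apply_gradients(gvs) for _ in range(num_passes)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for op in pass_ops:
        sess.run(op, feed_dict={x: [[1.0, 2.0, 3.0]]})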
Python tf.scatter_update() example source code
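
For reference, tf.scatter_update in TF 1.x overwrites whole rows of a variable in place, selected by index. A minimal standalone demonstration:

import tensorflow as tf

ref = tf.Variable([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
# Replace rows 0 and 2 of `ref` with the given rows.
update_op = tf.scatter_update(ref, [0, 2], [[9.0, 9.0], [7.0, 7.0]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(update_op))
    # [[9. 9.]
    #  [2. 2.]
    #  [7. 7.]]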
def run(self, x, eta, idx_center=None, idx_sample=None):
    """Runs the network; x must be of size [B, H, W, C]."""
    h = [None] * self.num_layer
    embeddings = []
    reg_ops = []
    reset_ops = []
    clustering_ops = []
    with tf.variable_scope(self.scope):
        for ii in range(self.num_layer):
            input_vec = x if ii == 0 else h[ii - 1]
            h[ii] = tf.nn.conv2d(
                input_vec, self.w[ii],
                self.conv_filters['filter_stride'][ii], padding='SAME')
            if self.add_bias:
                h[ii] += self.b[ii]
            # Choose the embedding layout used for clustering.
            if self.clustering_type[ii] in ('sample', 'spatial'):
                embedding = h[ii]
            elif self.clustering_type[ii] == 'channel':
                # Cluster along channels: move C next to the batch axis.
                embedding = tf.transpose(h[ii], [0, 3, 1, 2])
            if self.clustering_shape[ii] is not None:
                embedding = tf.reshape(
                    embedding, [-1, self.clustering_shape[ii][1]])
                embeddings += [embedding]
                clustering_ops += [kmeans_clustering(
                    embedding, self.cluster_center[ii],
                    self.cluster_label[ii], self.num_cluster[ii], eta)]
                # Pull each sample toward its (non-differentiable) center.
                sample_center = tf.stop_gradient(
                    tf.gather(self.cluster_center[ii], self.cluster_label[ii]))
                reg_ops += [tf.reduce_mean(
                    tf.square(embedding - sample_center)) * self.alpha[ii] / 2.0]
                # Re-seed the selected centers with sampled embeddings.
                reset_ops += [tf.scatter_update(
                    self.cluster_center[ii], idx_center[ii],
                    tf.gather(embedding, idx_sample[ii]))]
            if self.act_func[ii] is not None:
                h[ii] = self.act_func[ii](h[ii])
            if self.pool_func[ii] is not None:
                h[ii] = self.pool_func[ii](
                    h[ii], ksize=self.pooling['pool_size'][ii],
                    strides=self.pooling['pool_stride'][ii], padding='SAME')
    return h, embeddings, clustering_ops, reg_ops, reset_ops
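
The kmeans_clustering helper called above is not defined in this snippet. A minimal sketch of what it plausibly does, assuming it reassigns each embedding row to its nearest center and moves each center toward the mean of its assigned rows with step size eta (the argument names follow the call site; the update rule itself is an assumption):

import tensorflow as tf

def kmeans_clustering(embedding, centers, labels, num_cluster, eta):
    # Squared distance from every embedding row to every center: [N, K].
    dist = tf.reduce_sum(
        tf.square(tf.expand_dims(embedding, 1) - tf.expand_dims(centers, 0)),
        axis=2)
    # Assumes `labels` is a variable with one entry per embedding row.
    assign = tf.cast(tf.argmin(dist, axis=1), labels.dtype)
    label_op = labels.assign(assign)
    # Per-cluster mean of the assigned rows.
    seg = tf.cast(assign, tf.int32)
    sums = tf.unsorted_segment_sum(embedding, seg, num_cluster)
    counts = tf.unsorted_segment_sum(
        tf.ones_like(embedding[:, :1]), seg, num_cluster)
    means = sums / tf.maximum(counts, 1.0)
    # Move only the non-empty centers toward the batch means, step size eta.
    mask = tf.cast(counts > 0, centers.dtype)
    center_op = centers.assign(centers + eta * mask * (means - centers))
    return tf.group(label_op, center_op)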