def _clip_sparse(self, grad, var):
    """Clip the norms of the slices of ``var`` touched by a sparse gradient.

    Looks up the clipping dimensions registered for ``var`` in
    ``self._vars_to_clip_dims``.

    - If those dimensions do not include dimension 0, only the slices of
      ``var`` selected by ``grad.indices`` are handled. They are gathered,
      clipped with ``tf.clip_by_norm`` to ``self._max_norm`` over the clip
      dimensions, and written back with ``var.scatter_sub``.
    - If dimension 0 is included, a warning is logged and the method falls
      back to ``self._clip_dense(var)``.

    Args:
      grad: A ``tf.IndexedSlices`` gradient for ``var``. Only its ``indices``
        and ``dense_shape`` are read; its values are not used here.
      var: The variable whose touched slices are clipped in place.

    Returns:
      Whatever ``var.scatter_sub`` returns, or the result of
      ``self._clip_dense(var)`` on the fallback path.
    """
    # Internal invariant: only sparse gradients are routed to this method.
    assert isinstance(grad, tf.IndexedSlices)
    clip_dims = self._vars_to_clip_dims[var]
    if 0 in clip_dims:
        # Dimension 0 is the sparse (indexed) dimension. A norm taken across
        # it is not confined to the gathered slices, so clip the whole
        # variable instead.
        log.warning("Clipping norm across dims %s for %s is inefficient "
                    "when including sparse dimension 0.", clip_dims,
                    var.op.name)
        return self._clip_dense(var)

    with tf.colocate_with(var):
        # A sparse gradient may reference the same slice more than once.
        # scatter_sub accumulates the contributions of repeated indices.
        # Without deduplication, the correction below would be subtracted
        # once per occurrence and overshoot the clipped value.
        unique_indices, _ = tf.unique(grad.indices)
        var_subset = tf.gather(var, unique_indices)
    # NOTE(review): `_maybe_colocate_with` is defined elsewhere; presumably it
    # colocates with `var` only when configured to — confirm.
    with self._maybe_colocate_with(var):
        normalized_var_subset = tf.clip_by_norm(
            var_subset, self._max_norm, clip_dims)
        # Subtracting (var - clipped) leaves exactly the clipped value in
        # each touched slice.
        delta = tf.IndexedSlices(
            var_subset - normalized_var_subset, unique_indices,
            grad.dense_shape)
    with tf.colocate_with(var):
        return var.scatter_sub(delta, use_locking=self._use_locking)
# Scraped web-page residue, not part of the code:
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").