def _scatter_f_var(self, dst, src, mode="update"):
    """Write ``src`` into the rows of ``self.bases[dst.key]`` selected by ``dst``.

    The target variable is updated in place with TensorFlow variable ops
    (``assign``/``assign_add`` or ``scatter_update``/``scatter_add``). Per the
    original author's note, scattering directly into the variable was measured
    to be faster than an approach based on ``scatter_nd``.

    Parameters
    ----------
    dst
        Destination descriptor. Only ``dst.key``, ``dst.indices``,
        ``dst.as_slice`` and ``dst.tf_indices`` are read here (presumably a
        signal view onto a base variable -- confirm against the caller).
    src
        Tensor of values to write; its rows correspond to ``dst.indices``.
    mode : str
        ``"inc"`` adds ``src`` to the selected rows; any other value
        overwrites them (``mode`` is not validated).

    Returns
    -------
    The op/tensor returned by the TensorFlow update call.
    """
    var = self.bases[dst.key]

    # Pick the whole-variable op and the row-scatter op for the requested mode.
    if mode == "inc":
        full_op, scatter_op = tf.assign_add, tf.scatter_add
    else:
        full_op, scatter_op = tf.assign, tf.scatter_update

    covers_all_rows = False
    if (dst.as_slice is not None and
            var.get_shape().is_compatible_with(src.get_shape())):
        # Leading (row) dimension of the target; None if not statically known.
        n_rows = var.get_shape()[0].value
        # Assumes a non-None ``as_slice`` means ``dst.indices`` is evenly
        # spaced, so first == 0, last == n_rows - 1 and length == n_rows
        # together imply the indices are exactly 0..n_rows-1 -- TODO confirm.
        # The length check runs before any indexing so empty ``indices`` cannot
        # raise IndexError, and an unknown row count falls back to scattering.
        covers_all_rows = (
            n_rows is not None and
            n_rows > 0 and
            len(dst.indices) == n_rows and
            dst.indices[0] == 0 and
            dst.indices[-1] == n_rows - 1
        )

    if covers_all_rows:
        # Every row is written in order, so a whole-variable update replaces
        # the scatter.
        return full_op(var, src, use_locking=False)
    return scatter_op(var, dst.tf_indices, src, use_locking=False)
评论列表
文章目录