def scatter_add_tensor(ref, indices, updates, name=None):
    """
    Return ``ref`` with sparse ``updates`` added at the positions in ``indices``.

    This is a tensor (non-variable) counterpart of a scatter-add: nothing is
    modified in place. The updates are first scattered into a zero tensor of
    the same shape as ``ref`` with ``scatter_nd`` and the result is then added
    to ``ref``, producing a new tensor.

    Duplicate indices: if multiple indices reference the same location,
    their contributions add (this is the documented ``scatter_nd``
    behaviour).

    Shape contract (that of ``scatter_nd``, not of ``scatter_add``): the last
    dimension of ``indices`` indexes into the leading dimensions of ``ref``,
    so with ``K = indices.shape[-1]`` it requires ``K <= rank(ref)`` and
    ``updates.shape = indices.shape[:-1] + ref.shape[K:]``.

    :param ref: A Tensor, or a value convertible to one, to add the updates to.
    :param indices: A Tensor of type int32 or int64 (its dtype is also used as
        the dtype of the computed shape tensors). Each innermost vector is an
        index into the leading dimensions of ref.
    :param updates: A Tensor of values to add to ref. It is added to ref
        without any cast, so it is expected to have the same dtype as ref.
    :param name: A name for the operation (optional).
    :return: A new Tensor with the shape of ref holding ``ref`` plus the
        scattered updates.
    """
    # NOTE(review): the three-argument (name, default_name, values) call looks
    # like the TF 1.x ``name_scope`` signature - confirm against the
    # TensorFlow version this file targets.
    with tensorflow.name_scope(name, 'scatter_add_tensor', [ref, indices, updates]) as scope:
        ref = tensorflow.convert_to_tensor(ref, name='ref')
        indices = tensorflow.convert_to_tensor(indices, name='indices')
        updates = tensorflow.convert_to_tensor(updates, name='updates')

        # scatter_nd needs the output shape in the same integer dtype as the
        # indices, hence out_type=indices.dtype.
        ref_shape = tensorflow.shape(ref, out_type=indices.dtype, name='ref_shape')

        # Dense tensor shaped like ref: zeros everywhere except at the indexed
        # positions, where the (summed) updates are placed.
        scattered_updates = tensorflow.scatter_nd(indices, updates, ref_shape, name='scattered_updates')

        # Defensive runtime check that the scattered tensor really has ref's
        # shape before the two are added element-wise.
        scattered_shape = tensorflow.shape(scattered_updates, out_type=indices.dtype)
        shape_check = tensorflow.assert_equal(ref_shape, scattered_shape)
        with tensorflow.control_dependencies([shape_check]):
            # Passing the scope string as the op name makes the final add op
            # carry the name of the enclosing scope.
            output = tensorflow.add(ref, scattered_updates, name=scope)
        return output
评论列表
文章目录