def gather(self, src, force_copy=False):
    """Read the data addressed by ``src`` out of its base array.

    Parameters
    ----------
    src : :class:`.TensorSignal`
        Signal describing which rows of which base array to read.
    force_copy : bool, optional
        When True, always read via ``tf.gather`` (which copies) instead of
        trying the cheaper identity/strided-slice paths. Passing False does
        not promise that no copy happens.

    Returns
    -------
    ``tf.Tensor``
        Dense tensor holding the selected data. It is reshaped to
        ``src.full_shape`` when its static shape differs from it.

    Raises
    ------
    BuildError
        If ``src.tf_indices`` has not been created yet.
    """
    if src.tf_indices is None:
        raise BuildError("Indices for %s have not been loaded into "
                         "TensorFlow" % src)

    indices = src.indices

    logger.debug("gather")
    logger.debug("src %s", src)
    logger.debug("indices %s", indices)
    base = self.bases[src.key]
    logger.debug("src base %s", base)

    def _reads_whole_base():
        # True when the indices start at row 0, end at the last row of
        # ``base`` and contain one entry per row; checked in this order so
        # the shape lookups only run once the first test passes
        if not indices[0] == 0:
            return False
        n_rows = base.get_shape()[0]
        if not indices[-1] == n_rows.value - 1:
            return False
        return len(indices) == n_rows

    # identity and strided_slice are cheaper than gather, so gather is only
    # used when forced or when the indices cannot be expressed as a slice
    if force_copy or src.as_slice is None:
        result = tf.gather(base, src.tf_indices)
    elif _reads_whole_base():
        result = base
    else:
        result = tf.strided_slice(base, *src.as_slice)

    # static shape inference is unreliable here, so pin the shape to
    # (number of indices,) + the trailing dimensions of the base array
    leading_dims = src.tf_indices.get_shape()[:1]
    trailing_dims = base.get_shape()[1:]
    result.set_shape(leading_dims.concatenate(trailing_dims))

    # adopt the shape requested by ``src``; otherwise the base array's
    # layout is kept
    if result.get_shape() != src.full_shape:
        result = tf.reshape(result, src.tf_shape)
        result.set_shape(src.full_shape)

    # record the read so that later writes to this array are scheduled
    # after it
    self.mark_gather(src)

    return result
评论列表
文章目录