signals.py source code (Python)

Project: nengo_dl, author: nengo
```python
# Excerpt: a method of the class in nengo_dl's signals.py that stores the
# base arrays in ``self.bases``. The module-level names it relies on are
# shown first. Uses the TensorFlow 1.x API (e.g. ``Dimension.value``).
import logging

import tensorflow as tf
from nengo.exceptions import BuildError

logger = logging.getLogger(__name__)


def gather(self, src, force_copy=False):
    """Fetches the data corresponding to ``src`` from the base array.

    Parameters
    ----------
    src : :class:`.TensorSignal`
        Signal indicating the data to be read from the base array
    force_copy : bool, optional
        If True, always perform a gather, not a slice (this forces a
        copy). Note that setting ``force_copy=False`` does not guarantee
        that a copy won't be performed.

    Returns
    -------
    ``tf.Tensor``
        Tensor object corresponding to a dense subset of data from the
        base array
    """

    if src.tf_indices is None:
        raise BuildError("Indices for %s have not been loaded into "
                         "TensorFlow" % src)

    logger.debug("gather")
    logger.debug("src %s", src)
    logger.debug("indices %s", src.indices)
    logger.debug("src base %s", self.bases[src.key])

    var = self.bases[src.key]

    # returning the base array itself, or reading a section of it with
    # `strided_slice`, is more efficient than `tf.gather`, so only gather
    # when forced to or when the indices cannot be expressed as a slice
    if force_copy or src.as_slice is None:
        result = tf.gather(var, src.tf_indices)
    elif (src.indices[0] == 0 and
          src.indices[-1] == var.get_shape()[0].value - 1 and
          len(src.indices) == var.get_shape()[0]):
        # `as_slice` is set, so the indices form a regular run; starting at
        # row 0, ending at the last row and having one index per row means
        # they cover the whole base array in order
        result = var
    else:
        result = tf.strided_slice(var, *src.as_slice)

    # static shape inference does not always succeed here, so set the
    # shape explicitly: (number of indices,) + trailing base dimensions
    result.set_shape(src.tf_indices.get_shape()[:1].concatenate(
        var.get_shape()[1:]))

    # reshape the data according to the shape set in `src`, if there is
    # one; otherwise keep the shape of the base array
    if result.get_shape() != src.full_shape:
        result = tf.reshape(result, src.tf_shape)
        result.set_shape(src.full_shape)

    # whenever we read from an array we use this to mark it as "read"
    # (so that any future writes to the array will be scheduled after
    # the read)
    self.mark_gather(src)

    return result
```
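The branch that chooses between `tf.gather`, `tf.strided_slice`, and returning the base array unchanged can be mimicked in plain Python. This is an illustrative sketch only: `indices_to_slice`, `choose_read_strategy`, and `read` are hypothetical helpers, Python lists stand in for TensorFlow tensors, and deriving `(begin, end, stride)` from the indices this way is an assumption, not how nengo_dl builds `as_slice`. It also shows why the whole-array check is safe: it only runs once the indices are known to form an evenly spaced run, and such a run that starts at 0, ends at the last row, and has one entry per row must have stride 1.

```python
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")
Slice = Tuple[int, int, int]  # (begin, end, stride), as in base[begin:end:stride]


def indices_to_slice(indices: Sequence[int]) -> Optional[Slice]:
    """Hypothetical helper: return (begin, end, stride) if ``indices`` is an
    evenly spaced increasing run, otherwise None."""
    if not indices:
        return None
    if len(indices) == 1:
        return (indices[0], indices[0] + 1, 1)
    stride = indices[1] - indices[0]
    if stride <= 0:
        return None  # repeated or decreasing indices need a gather
    if any(b - a != stride for a, b in zip(indices, indices[1:])):
        return None  # uneven spacing cannot be written as one slice
    return (indices[0], indices[-1] + 1, stride)


def choose_read_strategy(indices: Sequence[int], base_len: int,
                         force_copy: bool = False) -> Tuple[str, Optional[Slice]]:
    """Hypothetical mirror of the branches in ``gather``."""
    as_slice = indices_to_slice(indices)
    if force_copy or as_slice is None:
        return ("gather", None)
    if indices[0] == 0 and indices[-1] == base_len - 1 and len(indices) == base_len:
        return ("whole", None)
    return ("slice", as_slice)


def read(base: List[T], indices: Sequence[int], force_copy: bool = False) -> List[T]:
    """Hypothetical reader that applies the chosen strategy to a list."""
    strategy, sl = choose_read_strategy(indices, len(base), force_copy)
    if strategy == "gather":
        return [base[i] for i in indices]  # always builds a new list
    if strategy == "whole":
        return base  # the base itself is returned, no copy is made
    assert sl is not None
    begin, end, stride = sl
    return base[begin:end:stride]


base = [10, 11, 12, 13, 14, 15]
print(choose_read_strategy([1, 3, 5], len(base)))  # ('slice', (1, 6, 2))
print(read(base, [1, 3, 5]))                        # [11, 13, 15]
print(read(base, [4, 0, 2]))                        # [14, 10, 12]
print(read(base, list(range(6))) is base)           # True
```

Passing `force_copy=True` routes even a full, in-order index list to the gather branch, matching the documented behaviour of the `force_copy` parameter.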
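The comment on `self.mark_gather(src)` describes read-before-write ordering: a write to a base array must not be scheduled until the earlier reads of that array are done. A minimal sketch of that bookkeeping, assuming operations are identified by plain name strings, is below. `AccessTracker`, `mark_read`, and `mark_write` are hypothetical names and do not reproduce nengo_dl's actual `mark_gather`.

```python
from collections import defaultdict
from typing import Dict, List


class AccessTracker:
    """Hypothetical read/write bookkeeping, not nengo_dl's implementation.

    Reads of a base array are recorded; a later write to the same base
    receives those reads (plus the previous write) as dependencies.
    """

    def __init__(self) -> None:
        self.pending_reads: Dict[str, List[str]] = defaultdict(list)
        self.last_write: Dict[str, str] = {}

    def mark_read(self, key: str, op_name: str) -> None:
        # Remember that `op_name` reads from base array `key`.
        self.pending_reads[key].append(op_name)

    def mark_write(self, key: str, op_name: str) -> List[str]:
        # The write must run after every read recorded since the last write,
        # and after that last write itself.
        deps = list(self.pending_reads[key])
        if key in self.last_write:
            deps.append(self.last_write[key])
        self.pending_reads[key] = []
        self.last_write[key] = op_name
        return deps


t = AccessTracker()
t.mark_read("a", "r1")
t.mark_read("a", "r2")
print(t.mark_write("a", "w1"))  # ['r1', 'r2']
t.mark_read("a", "r3")
print(t.mark_write("a", "w2"))  # ['r3', 'w1']
```

Because each write also depends on the previous write to the same base, reads recorded before that earlier write remain ordered before the new one by transitivity, which is why the read list can be cleared after each write.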