tfutil.py source file

python

Project: rltools  Author: sisl
import tensorflow as tf


def lookup_last_idx(a, inds, name=None):
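    """
    Looks up one element along the last axis of a for each entry of inds,
    in the spirit of fancy indexing, e.g. a[[1, 2, 3]] = [a[1], a[2], a[3]].
    a is a d1 x d2 x ... x dn tensor
    inds is a d1 x d2 x ... x d(n-1) tensor of integers in [0, dn)
    Returns the d1 x d2 x ... x d(n-1) tensor
    out[i_1,...,i_{n-1}] = a[i_1,...,i_{n-1}, inds[i_1,...,i_{n-1}]]
    """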
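    # tf.op_scope was removed in TensorFlow 1.0; the TF 1.x equivalent is
    # tf.name_scope(name, default_name, values).
    with tf.name_scope(name, 'lookup_last_idx', [a, inds]) as scope: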
        a = tf.convert_to_tensor(a, name='a')
        inds = tf.convert_to_tensor(inds, name='inds')

        # Flatten the arrays
        ashape, indsshape = tf.shape(a), tf.shape(inds)
        aflat, indsflat = tf.reshape(a, [-1]), tf.reshape(inds, [-1])

        # Compute the indices corresponding to inds in the flattened array
        # TODO Causes UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape.
        delta = tf.gather(ashape, tf.size(ashape) - 1)  # i.e. delta = ashape[-1]
        aflatinds = tf.range(0, limit=tf.size(a), delta=delta) + indsflat

        # Look up the desired elements in the flattened array, and reshape
        # to the original shape
        return tf.reshape(tf.gather(aflat, aflatinds), indsshape, name=scope)
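
The flat-index trick used above can be checked outside TensorFlow. The following is a minimal NumPy sketch of the same idea, not part of rltools: the helper name `lookup_last_idx_np` and the sample arrays are hypothetical and chosen only for illustration. Each row of the flattened array starts at a multiple of the last-axis length, so adding the per-row index to that offset gives the position of the desired element.

```python
import numpy as np


def lookup_last_idx_np(a, inds):
    """NumPy sketch (hypothetical helper): out[...] = a[..., inds[...]]."""
    a = np.asarray(a)
    inds = np.asarray(inds)
    delta = a.shape[-1]                        # length of the last axis
    row_offsets = np.arange(0, a.size, delta)  # start of each row in the flat array
    flat_inds = row_offsets + inds.reshape(-1)
    # Gather from the flattened array, then restore the shape of inds
    return a.reshape(-1)[flat_inds].reshape(inds.shape)


a = np.array([[10, 11, 12],
              [20, 21, 22]])
inds = np.array([2, 0])
out = lookup_last_idx_np(a, inds)
print(out)  # [12 20]
```

On NumPy versions that provide `np.take_along_axis`, the same result should come from `np.take_along_axis(a, inds[..., None], axis=-1)[..., 0]`, which is a convenient cross-check for higher-rank inputs.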