def lookup_last_idx(a, inds, name=None):
    """Pick one element from the last axis of `a` at every leading position.

    With `a` of shape (d1, ..., dn) and `inds` of shape (d1, ..., d(n-1)),
    returns `out` of shape (d1, ..., d(n-1)) where

        out[i_1, ..., i_{n-1}] = a[i_1, ..., i_{n-1}, inds[i_1, ..., i_{n-1}]]

    For a 2-D `a` this is simply out[i] = a[i, inds[i]].

    Args:
        a: Tensor (or value convertible to one) of shape (d1, ..., dn).
        inds: Integer tensor (int32 or int64) whose shape is `a`'s shape
            without the last axis. Values are expected to lie in [0, dn);
            this is NOT validated -- an out-of-range value addresses a
            different row of the flattened `a` (or falls outside it)
            instead of raising a clear error.
        name: Optional name for the enclosing op scope.

    Returns:
        Tensor with the shape of `inds` and the dtype of `a`.
    """
    with tf.op_scope([a, inds], name, 'lookup_last_idx') as scope:
        a = tf.convert_to_tensor(a, name='a')
        inds = tf.convert_to_tensor(inds, name='inds')
        # Flatten both tensors so the lookup becomes a single 1-D gather.
        ashape, indsshape = tf.shape(a), tf.shape(inds)
        aflat, indsflat = tf.reshape(a, [-1]), tf.reshape(inds, [-1])
        # Length of the last axis, i.e. delta = ashape[-1].
        # TODO Causes UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape.
        delta = tf.gather(ashape, tf.size(ashape) - 1)
        # Flat offset at which each row of `a` starts: 0, dn, 2*dn, ...
        # tf.size/tf.shape yield int32, so cast the offsets to the index dtype;
        # otherwise int64 indices (e.g. straight from tf.argmax) fail to add.
        row_starts = tf.cast(tf.range(0, limit=tf.size(a), delta=delta),
                             inds.dtype)
        aflatinds = row_starts + indsflat
        # Look up the desired elements in the flattened array, and reshape
        # back to the shape of `inds`.
        return tf.reshape(tf.gather(aflat, aflatinds), indsshape, name=scope)
评论列表
文章目录