def batch_gather(tensor, indices):
    """Gather in batch from a tensor of arbitrary size.

    In pseudocode this function produces the following:

        output[i] = tf.gather(tensor[i], indices[i])

    The batch dimension is folded into the gather axis. ``tensor`` is
    reshaped to ``[batch * n, ...]``, and each row of ``indices`` is shifted
    by ``i * n`` so that it addresses the slice belonging to batch element
    ``i``. A single ``tf.gather`` on axis 0 then does all batches at once.

    Args:
      tensor: Tensor of rank >= 2 with shape ``[batch, n, ...]``.
      indices: Integer tensor (or a value convertible to one) whose leading
        dimension is ``batch``. It may carry extra trailing dimensions,
        e.g. ``[batch, k]``, and its rank must be statically known. Entries
        are expected to lie in ``[0, n)``. Out-of-range values are not
        checked and would read rows belonging to another batch element.

    Returns:
      output: A tensor of gathered values with shape
        ``indices.shape + tensor.shape[2:]``.
    """
    # NOTE(review): assumes get_shape returns a Python list whose entries are
    # static ints or scalar tensors for unknown dims -- confirm against its
    # definition.
    shape = get_shape(tensor)
    # Merge the first two dims so one gather on axis 0 can reach any row.
    flat_first = tf.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
    indices = tf.convert_to_tensor(indices)
    # Per-batch offset shaped [batch, 1, ..., 1] so it broadcasts against
    # every trailing dimension of ``indices``.
    offset_shape = [shape[0]] + [1] * (indices.shape.ndims - 1)
    offset = tf.reshape(tf.range(shape[0]) * shape[1], offset_shape)
    # tf.range produces int32 and TF does not promote dtypes, so int64
    # indices would fail on the addition below without this cast.
    offset = tf.cast(offset, indices.dtype)
    output = tf.gather(flat_first, indices + offset)
    return output
# Residue from the source web page (not code):
# "评论列表" (Comment list), "文章目录" (Table of contents).