def _tensor_gather_helper(gather_indices, gather_from, batch_size,
                          range_size, gather_shape):
    """Gather per-batch indices from ``gather_from`` while keeping batch order.

    ``gather_from`` is reshaped to ``gather_shape`` (e.g. ``[-1]``), and the
    flattened indices are gathered from it. Each row ``b`` of
    ``gather_indices`` is first shifted by ``b * range_size``, so that
    row-local indices address the correct slice of the flattened leading
    dimension.

    Args:
        gather_indices: Row-local indices to gather. This is added to a
            ``[batch_size, 1]`` offset tensor, so it presumably has shape
            ``[batch_size, k]`` with values in ``[0, range_size)``. TODO:
            confirm against callers.
        gather_from: The tensor to gather from.
        batch_size: The input batch size.
        range_size: The number of values in each range. This is likely equal
            to the beam width.
        gather_shape: The shape to reshape ``gather_from`` to before
            gathering, so that the wanted trailing values are preserved.
            For example, when ``gather_from`` is an AttentionWrapperState
            attention of shape ``[batch_size, beam_width, attention_size]``,
            ``gather_shape`` is ``[batch_size * beam_width, -1]``. After the
            final reshape, ``attention_size`` is still the last dimension.

    Returns:
        The gathered tensor. It is reshaped to
        ``tf.shape(gather_from)[:1 + len(gather_shape)]``. Its static shape
        is set to an unknown leading dimension followed by
        ``gather_from.shape[1:1 + len(gather_shape)]``.
    """
    num_kept_dims = 1 + len(gather_shape)

    # Shift each row's indices by (row * range_size) so they point into the
    # flattened (batch * range) leading dimension of the reshaped source.
    row_starts = tf.range(batch_size) * range_size
    batch_offsets = tf.expand_dims(row_starts, 1)
    flat_indices = tf.reshape(gather_indices + batch_offsets, [-1])

    flat_source = tf.reshape(gather_from, gather_shape)
    gathered = tf.gather(flat_source, flat_indices)

    # Restore the leading dimensions of the original tensor. The batch
    # dimension is left statically unknown, and the trailing kept dimensions
    # are copied from gather_from's static shape.
    dynamic_shape = tf.shape(gather_from)[:num_kept_dims]
    trailing_static = gather_from.shape[1:num_kept_dims]
    static_shape = tf.TensorShape([None]).concatenate(trailing_static)

    result = tf.reshape(gathered, dynamic_shape)
    result.set_shape(static_shape)
    return result
# beam_aligner.py — file source code
# Language: python
# Views: 31
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Article table of contents