def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
    """Maybe applies _tensor_gather_helper.

    This applies _tensor_gather_helper when the rank of gather_from is at
    least the length of gather_shape. It is used in conjunction with nest so
    that _tensor_gather_helper is not applied to inapplicable values such as
    scalars.

    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The batch size.
      range_size: The number of values in each range. Likely equal to
        beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve
        the correct values. An example is when gather_from is the attention
        from an AttentionWrapperState with shape
        [batch_size, beam_width, attention_size]. There, we want to preserve
        the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.

    Returns:
      output: Gathered tensor of shape
        tf.shape(gather_from)[:1+len(gather_shape)], or the original tensor if
        it has too few dimensions.
    """
    # Only gather when gather_from has enough dimensions to be reshaped to
    # gather_shape; lower-rank values (e.g. scalars) pass through unchanged.
    if gather_from.shape.ndims >= len(gather_shape):
        return _tensor_gather_helper(
            gather_indices=gather_indices,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=range_size,
            gather_shape=gather_shape)
    else:
        return gather_from
Source file: beam_aligner.py
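
The rank check above can be tried without TensorFlow. The following is a minimal NumPy sketch, not the original implementation: `tensor_gather_sketch` and `maybe_tensor_gather_sketch` are hypothetical names, and the body of `tensor_gather_sketch` is an assumption about what `_tensor_gather_helper` does (offset the indices per batch entry, gather from the tensor reshaped to `gather_shape`, then restore the leading dimensions), since that helper is not shown here.

```python
import numpy as np


def tensor_gather_sketch(gather_indices, gather_from, batch_size,
                         range_size, gather_shape):
    """NumPy stand-in for _tensor_gather_helper (assumed behavior)."""
    # Shift each batch entry's indices into the flattened batch * range axis.
    offsets = (np.arange(batch_size) * range_size)[:, np.newaxis]
    flat_indices = (gather_indices + offsets).reshape(-1)
    # Reshape to gather_shape, pick rows, then restore the leading dimensions.
    flat = gather_from.reshape(gather_shape)
    final_shape = gather_from.shape[:1 + len(gather_shape)]
    return flat[flat_indices].reshape(final_shape)


def maybe_tensor_gather_sketch(gather_indices, gather_from, batch_size,
                               range_size, gather_shape):
    """Applies the gather only when gather_from has enough dimensions."""
    gather_from = np.asarray(gather_from)
    if gather_from.ndim >= len(gather_shape):
        return tensor_gather_sketch(gather_indices, gather_from, batch_size,
                                    range_size, gather_shape)
    # Too few dimensions (e.g. a scalar): return the value unchanged.
    return gather_from


batch_size, beam_width = 2, 3
# Illustrative attention values of shape [batch_size, beam_width, 2].
attention = np.arange(12).reshape(batch_size, beam_width, 2)
# Illustrative beam indices to gather for each batch entry.
parent_ids = np.array([[2, 0, 1], [1, 1, 0]])
gather_shape = [batch_size * beam_width, -1]

gathered = maybe_tensor_gather_sketch(
    parent_ids, attention, batch_size, beam_width, gather_shape)
print(gathered[0].tolist())  # [[4, 5], [0, 1], [2, 3]]

# A scalar has rank 0 < len(gather_shape), so it is returned as is.
scalar = maybe_tensor_gather_sketch(
    parent_ids, 7, batch_size, beam_width, gather_shape)
```

In the sketch, the rank-3 attention array is reordered per batch entry while the trailing dimension is preserved, and the scalar passes through untouched, which is the case the guard exists for when the function is mapped over a nested state structure.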