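def _maybe_merge_batch_beams(self, t, s):
    """Merges the batch and beam dimensions of a tensor, if it has both.

    More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We
    reshape this into [batch_size*beam_width, s]. A tensor of rank less than 2
    is returned unchanged.

    Args:
        t: Tensor of dimension [batch_size, beam_width, s].
        s: Tensor, Python int, or TensorShape.

    Returns:
        A reshaped version of t with dimension [batch_size*beam_width, s],
        or t itself if its rank is less than 2.

    Raises:
        TypeError: If t is an instance of TensorArray.
        ValueError: If the rank of t is not statically known.
    """
    # Assumes `import tensorflow as tf` at module level (not shown in this snippet).
    if isinstance(t, tf.TensorArray):
        raise TypeError("Expected a Tensor, but got a TensorArray: %s" % t)
    if t.shape.ndims is None:
        raise ValueError("Expected the rank of t to be statically known: %s" % t)
    # Only tensors with both a batch and a beam dimension can be merged.
    return self._merge_batch_beams(t, s) if t.shape.ndims >= 2 else t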
Source file: beam_aligner.py
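
The merge step can be illustrated without TensorFlow. The following is a minimal sketch, not the code from beam_aligner.py: it uses plain nested lists as a stand-in for tensors, and the helper names `rank`, `merge_batch_beams`, and `maybe_merge_batch_beams` are hypothetical. The `s` argument is omitted here because a nested list carries its own shape.

```python
def rank(x):
    """Return the nesting depth of a rectangular nested list (0 for a scalar)."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        # Descend into the first element; stop at an empty list.
        x = x[0] if x else None
    return depth


def merge_batch_beams(t):
    """Flatten [batch_size, beam_width, ...] into [batch_size * beam_width, ...]."""
    return [beam for batch_entry in t for beam in batch_entry]


def maybe_merge_batch_beams(t):
    """Merge the two leading dimensions only when t has at least two of them."""
    return merge_batch_beams(t) if rank(t) >= 2 else t


# Stand-in "tensor" with batch_size = 2, beam_width = 3, s = 2.
t = [
    [[0, 1], [2, 3], [4, 5]],
    [[6, 7], [8, 9], [10, 11]],
]
merged = maybe_merge_batch_beams(t)
print(len(merged), len(merged[0]))  # 6 2

# A rank-1 input has no beam dimension, so it is returned unchanged.
print(maybe_merge_batch_beams([1, 2, 3]))  # [1, 2, 3]
```

After merging, beam k of batch entry b sits at row b * beam_width + k, which is the same row-major ordering a reshape of the two leading dimensions would produce.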