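import tensorflow as tf


def _merge_batch_beams(self, t, s):
    """Merges the tensor from a batch of beams into a batch by beams.

    More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We
    reshape this into [batch_size * beam_width, s].

    Args:
      t: Tensor of dimension [batch_size, beam_width, s].
      s: Static shape of the trailing dimensions (TensorShape or list of ints).

    Returns:
      A reshaped version of t with dimension [batch_size * beam_width, s].
    """
    # Dynamic shape of t, so the trailing dims work even if unknown statically.
    t_shape = tf.shape(t)
    # Collapse the leading (batch, beam) dimensions into one; keep the rest.
    reshaped = tf.reshape(
        t, tf.concat(([self._batch_size * self._beam_width], t_shape[2:]), axis=0))
    # A reshape driven by a dynamic shape tensor may lose static shape info, so
    # reattach it; the merged leading dimension is left as None (unknown).
    reshaped.set_shape(tf.TensorShape([None]).concatenate(s))
    return reshaped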
Source: beam_aligner.py (Python)
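To make the row ordering of the merged tensor concrete, here is a framework-free sketch in plain Python. It is only an illustration, not the TensorFlow code path: nested lists stand in for a [batch_size, beam_width, s] tensor (each element is a label standing in for a depth-s vector), and `merge_batch_beams`, `split_batch_beams`, and `merged_index` are hypothetical helper names. It assumes the reshape follows TensorFlow's row-major ordering, so beam `k` of batch entry `b` lands in row `b * beam_width + k`.

```python
from typing import List, TypeVar

T = TypeVar("T")


def merge_batch_beams(t: List[List[T]]) -> List[T]:
    """Flatten a [batch_size][beam_width] nested list into [batch_size * beam_width].

    Row-major order: all beams of batch entry 0 come first, then batch entry 1,
    and so on, matching a reshape of the two leading dimensions into one.
    """
    return [beam for batch_entry in t for beam in batch_entry]


def split_batch_beams(t: List[T], beam_width: int) -> List[List[T]]:
    """Inverse of merge_batch_beams: regroup consecutive runs of beam_width rows."""
    if len(t) % beam_width != 0:
        raise ValueError("length must be a multiple of beam_width")
    return [t[i:i + beam_width] for i in range(0, len(t), beam_width)]


def merged_index(b: int, k: int, beam_width: int) -> int:
    """Row of the merged tensor that holds batch entry b, beam k."""
    return b * beam_width + k


if __name__ == "__main__":
    batch_size, beam_width = 2, 3
    # (b, k) labels stand in for the depth-s vectors of each beam.
    t = [[(b, k) for k in range(beam_width)] for b in range(batch_size)]
    merged = merge_batch_beams(t)
    print(merged)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
    print(merged[merged_index(1, 2, beam_width)])  # (1, 2)
```

Splitting the merged dimension back into [batch_size, beam_width] only recovers the original grouping because both directions use the same row-major ordering, which is why merge and split helpers are typically written as a matched pair.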