def _split_batch_beams(self, t, s):
    """Unfold a beam-flattened tensor so that beams get their own axis.

    The leading axis of `t` holds `batch_size * beam_width` entries. It is
    split into two axes, `[batch_size, beam_width]`, and every trailing
    axis of `t` is carried over unchanged.

    Args:
        t: Tensor of dimension `[batch_size * beam_width, s]`.
        s: Depth shape expected for the trailing axes (possibly only
            partially known).

    Returns:
        `t` reshaped to dimension `[batch_size, beam_width, s]`.

    Raises:
        ValueError: If the static shape of the reshaped tensor is not
            compatible with `[None, beam_width]` followed by `s`. The
            batch dimension is left unspecified in this check.
    """
    # Runtime target shape: [batch_size, beam_width] + trailing dims of t.
    dynamic_shape = tf.shape(t)
    target_dims = tf.concat(
        ([self._batch_size, self._beam_width], dynamic_shape[1:]), axis=0)
    reshaped = tf.reshape(t, target_dims)

    # Static prefix shared by the shape annotation and the expectation;
    # the batch dimension stays unknown (None) in both.
    static_prefix = tf.TensorShape([None, self._beam_width])

    # Pin the statically known beam width and trailing dims, which would
    # otherwise be lost because the reshape target is a dynamic tensor.
    reshaped.set_shape(static_prefix.concatenate(t.shape[1:]))

    expected_reshaped_shape = static_prefix.concatenate(s)
    if reshaped.shape.is_compatible_with(expected_reshaped_shape):
        return reshaped

    raise ValueError("Unexpected behavior when reshaping between beam width "
                     "and batch size. The reshaped tensor has shape: %s. "
                     "We expected it to have shape "
                     "(batch_size, beam_width, depth) == %s. Perhaps you "
                     "forgot to create a zero_state with "
                     "batch_size=encoder_batch_size * beam_width?"
                     % (reshaped.shape, expected_reshaped_shape))
beam_aligner.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录