beam_aligner.py source code

Project: almond-nnparser  Author: Stanford-Mobisocial-IoT-Lab

```python
import tensorflow as tf  # the method body uses the tf.* API


def _split_batch_beams(self, t, s):
    """Splits the tensor from a batch by beams into a batch of beams.

    More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
    reshape this into [batch_size, beam_width, s].

    Args:
      t: Tensor of dimension [batch_size*beam_width, s].
      s: (Possibly known) depth shape.

    Returns:
      A reshaped version of t with dimension [batch_size, beam_width, s].

    Raises:
      ValueError: If, after reshaping, the new tensor is not shaped
        `[batch_size, beam_width, s]` (assuming batch_size and beam_width
        are known statically).
    """
    t_shape = tf.shape(t)
    # Split the leading dimension into [batch_size, beam_width] and keep the
    # remaining (depth) dimensions unchanged.
    reshaped = tf.reshape(
        t, tf.concat(([self._batch_size, self._beam_width], t_shape[1:]), axis=0))
    # The target shape is computed at run time, so static shape inference may be
    # incomplete; pin the statically known dims (the batch dim stays None).
    reshaped.set_shape(tf.TensorShape([None, self._beam_width]).concatenate(t.shape[1:]))
    expected_reshaped_shape = tf.TensorShape([None, self._beam_width]).concatenate(s)
    if not reshaped.shape.is_compatible_with(expected_reshaped_shape):
        raise ValueError("Unexpected behavior when reshaping between beam width "
                         "and batch size.  The reshaped tensor has shape: %s.  "
                         "We expected it to have shape "
                         "(batch_size, beam_width, depth) == %s.  Perhaps you "
                         "forgot to create a zero_state with "
                         "batch_size=encoder_batch_size * beam_width?"
                         % (reshaped.shape, expected_reshaped_shape))
    return reshaped
```
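
To make the batch-by-beams layout concrete, here is a framework-free sketch in plain Python. It assumes the row-major layout that `tf.reshape` uses, where flat row `i * beam_width + j` holds beam `j` of batch entry `i`. The helpers `split_batch_beams`, `merge_batch_beams` and `is_compatible` are hypothetical illustrations, not part of almond-nnparser or TensorFlow; `is_compatible` is a simplified model of `TensorShape.is_compatible_with` that only handles shapes of known rank, with `None` standing for an unknown dimension as in the method above.

```python
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def split_batch_beams(rows: Sequence[T], batch_size: int, beam_width: int) -> List[List[T]]:
    """Hypothetical helper: regroup a flat [batch_size * beam_width] sequence
    into [batch_size][beam_width], so flat row i * beam_width + j becomes [i][j]."""
    if len(rows) != batch_size * beam_width:
        # A reshape with an explicit batch_size would also fail in this case.
        raise ValueError(
            "expected %d rows (batch_size * beam_width), got %d"
            % (batch_size * beam_width, len(rows)))
    return [list(rows[i * beam_width:(i + 1) * beam_width]) for i in range(batch_size)]


def merge_batch_beams(batch: Sequence[Sequence[T]]) -> List[T]:
    """Hypothetical inverse: flatten [batch_size][beam_width] back into one list."""
    return [row for beams in batch for row in beams]


def is_compatible(shape_a: Sequence[Optional[int]], shape_b: Sequence[Optional[int]]) -> bool:
    """Simplified model of shape compatibility for shapes of known rank:
    ranks must match, and each dimension pair must be equal or contain None."""
    if len(shape_a) != len(shape_b):
        return False
    return all(a is None or b is None or a == b for a, b in zip(shape_a, shape_b))


# Usage: batch_size=2, beam_width=3; "b<batch>k<beam>" stands in for a depth vector.
flat = ["b0k0", "b0k1", "b0k2", "b1k0", "b1k1", "b1k2"]
grouped = split_batch_beams(flat, batch_size=2, beam_width=3)
print(grouped)  # [['b0k0', 'b0k1', 'b0k2'], ['b1k0', 'b1k1', 'b1k2']]
assert merge_batch_beams(grouped) == flat

# The expected shape leaves batch_size unknown (None), as _split_batch_beams does.
print(is_compatible([2, 3, 128], [None, 3, 128]))  # True
print(is_compatible([2, 4, 128], [None, 3, 128]))  # False: beam width differs
```

Leaving the batch dimension as `None` means the check only catches a mismatch in beam width or depth, which is the failure the error message in `_split_batch_beams` hints at (a `zero_state` created with the wrong batch size).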