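def _beam_where(self, cond, x, y):
    """Per (batch, beam) select: take x where cond is True, else y."""
    # x, y: [batch_size, beam_width, ...]; cond: [batch_size, beam_width] booleans.
    assert x.shape.is_compatible_with(y.shape)
    original_static_shape = x.shape
    # Flatten batch and beam into one leading dimension of size batch_size * beam_width.
    cond = tf.reshape(cond, [self.batch_size * self._beam_width])
    x = self._merge_batch_beams(x, original_static_shape[2:])
    y = self._merge_batch_beams(y, original_static_shape[2:])
    # TF1 tf.where: a rank-1 cond selects whole rows along the first dimension of x/y.
    # (TF2's tf.where broadcasts instead; there, tf.compat.v1.where keeps this behavior.)
    return self._split_batch_beams(
        tf.where(cond, x, y), original_static_shape[2:])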
Source: beam_aligner.py (Python)
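To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the same merge, select, split pattern. It is an illustrative analogue, not the original TensorFlow method: `beam_where` is a hypothetical standalone function, plain `reshape` calls stand in for `_merge_batch_beams` / `_split_batch_beams`, and `np.where` with an explicitly reshaped mask stands in for TF1's row-wise `tf.where`.

```python
import numpy as np


def beam_where(cond, x, y):
    """Hypothetical NumPy analogue of _beam_where (illustrative sketch only).

    cond: bool array [batch, beam]; x, y: arrays [batch, beam, ...].
    Returns x[b, k] where cond[b, k] is True, else y[b, k].
    """
    assert x.shape == y.shape and cond.shape == x.shape[:2]
    batch, beam = cond.shape
    inner = x.shape[2:]
    # Merge batch and beam into one leading axis (stand-in for _merge_batch_beams).
    flat_cond = cond.reshape(batch * beam)
    flat_x = x.reshape((batch * beam,) + inner)
    flat_y = y.reshape((batch * beam,) + inner)
    # Row-wise select: append singleton axes so each cond entry broadcasts over
    # its whole row, mirroring TF1 tf.where with a rank-1 condition.
    mask = flat_cond.reshape((batch * beam,) + (1,) * len(inner))
    merged = np.where(mask, flat_x, flat_y)
    # Split the leading axis back into [batch, beam] (stand-in for _split_batch_beams).
    return merged.reshape((batch, beam) + inner)


if __name__ == "__main__":
    x = np.arange(12).reshape(2, 2, 3)
    cond = np.array([[True, False], [False, True]])
    out = beam_where(cond, x, -x)
    print(out[:, :, 0].tolist())  # [[0, -3], [-6, 9]]
```

Merging first means the selection only ever has to handle one leading axis, which is why the original method reshapes `cond` to length `batch_size * beam_width` before calling `tf.where`.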