beamable_mm.py source code

python

Project: fairseq-py  Author: facebookresearch
import torch  # needed for torch.mm below

# Excerpt: forward() method of the module class defined in beamable_mm.py
def forward(self, input1, input2):
    if (
        not self.training and           # test mode
        self.beam_size is not None and  # beam size is set
        input1.dim() == 3 and           # only support batched input
        input1.size(1) == 1             # single time step update
    ):
        bsz, beam = input1.size(0), self.beam_size

        # bsz x 1 x nhu --> bsz/beam x beam x nhu
        input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)

        # bsz x nhu x sz2 --> bsz/beam x nhu x sz2
        # (keeps only the first entry of each beam group; this presumes the
        # beam entries of one group share the same input2 matrix)
        input2 = input2.unfold(0, beam, beam)[:, :, :, 0]

        # use non batched operation if bsz = beam
        if input1.size(0) == 1:
            output = torch.mm(input1[0, :, :], input2[0, :, :])
        else:
            output = input1.bmm(input2)
        return output.view(bsz, 1, -1)
    else:
        # training, or no beam size set: plain batched matrix multiply
        return input1.bmm(input2)
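
The beam-folding branch above can be checked against a plain `bmm` with a small standalone sketch. The helper name `beamable_bmm` and the tensor sizes are hypothetical and chosen only for illustration. The sketch assumes `input2` has shape bsz x nhu x sz2 and that each source sentence's matrix is repeated `beam_size` times in consecutive rows, which is the situation the shortcut relies on.

```python
import torch


def beamable_bmm(input1, input2, beam_size):
    """Hypothetical standalone version of the beam-folding branch.

    input1: bsz x 1 x nhu (one decoder time step)
    input2: bsz x nhu x sz2, assumed identical within each beam group
    """
    bsz = input1.size(0)

    # bsz x 1 x nhu --> bsz/beam x beam x nhu
    folded1 = input1[:, 0, :].unfold(0, beam_size, beam_size).transpose(2, 1)

    # bsz x nhu x sz2 --> bsz/beam x nhu x sz2 (keep the first copy per group)
    folded2 = input2.unfold(0, beam_size, beam_size)[:, :, :, 0]

    if folded1.size(0) == 1:
        # a single sentence: an ordinary 2-D matrix product is enough
        out = torch.mm(folded1[0], folded2[0])
    else:
        out = folded1.bmm(folded2)

    # back to bsz x 1 x sz2
    return out.reshape(bsz, 1, -1)


def matches_plain_bmm(n_sent, beam_size, nhu=5, sz2=7):
    """Compare the folded product with input1.bmm(input2) on random data."""
    torch.manual_seed(0)
    per_sentence = torch.randn(n_sent, nhu, sz2)
    # repeat each sentence's matrix beam_size times in consecutive rows
    input2 = per_sentence.repeat_interleave(beam_size, dim=0)
    input1 = torch.randn(n_sent * beam_size, 1, nhu)
    expected = input1.bmm(input2)
    actual = beamable_bmm(input1, input2, beam_size)
    return actual.shape == expected.shape and torch.allclose(actual, expected, atol=1e-5)


if __name__ == "__main__":
    print(matches_plain_bmm(n_sent=3, beam_size=4))  # batched path
    print(matches_plain_bmm(n_sent=1, beam_size=4))  # torch.mm path
```

If the rows of `input2` inside a beam group differ, the folded product silently uses only the first row's matrix, so the two results no longer agree.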