def forward(self, input1, input2):
    """Batched matrix multiply with a shortcut for single-step beam decoding.

    The shortcut is taken only when the module is not in training mode,
    ``self.beam_size`` is set, ``input1`` is 3-D with a second dimension of
    size 1, and the batch splits evenly into groups of ``beam_size`` rows.
    On that path the rows of each group are multiplied against the *first*
    ``input2`` entry of the group only.
    NOTE(review): this presumes ``input2`` is identical for every row inside
    a group (e.g. replicated per beam) - confirm against the callers.

    Args:
        input1: left operand; ``(bsz, 1, nhu)`` on the shortcut path.
        input2: right operand; ``(bsz, nhu, sz2)`` on the shortcut path.

    Returns:
        A ``(bsz, 1, sz2)`` tensor on the shortcut path, otherwise the
        result of ``input1.bmm(input2)``.
    """
    beam = self.beam_size
    use_beam_path = (
        not self.training  # inference mode only
        and beam is not None  # beam size is set
        and input1.dim() == 3  # only batched input is supported
        and input1.size(1) == 1  # single time step update
    )
    if use_beam_path:
        bsz = input1.size(0)
        # unfold() silently drops rows that do not fill a whole group, after
        # which view(bsz, 1, -1) either fails or mis-shapes the result, so
        # only regroup when the batch divides evenly into beams.
        use_beam_path = beam > 0 and bsz > 0 and bsz % beam == 0

    if not use_beam_path:
        return input1.bmm(input2)

    # (bsz, 1, nhu) --> (bsz/beam, beam, nhu)
    input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)

    # (bsz, nhu, sz2) --> (bsz/beam, nhu, sz2), keeping the first entry of
    # every group of `beam` rows
    input2 = input2.unfold(0, beam, beam)[:, :, :, 0]

    # use the non-batched operation when there is a single group (bsz == beam)
    if input1.size(0) == 1:
        output = torch.mm(input1[0, :, :], input2[0, :, :])
    else:
        output = input1.bmm(input2)
    return output.view(bsz, 1, -1)
# Comment list
# Table of contents