def forward(self, src, src_pos, tgt, tgt_pos):
tgt, tgt_pos = tgt[:, :-1], tgt_pos[:, :-1]
enc_outputs = self.enc(src, src_pos)
dec_output = self.dec(enc_outputs, src, tgt, tgt_pos)
out = self.linear(dec_output)
return F.log_softmax(out.view(-1, self.dec_vocab_size))
评论列表
文章目录