def forward(self, enc_outputs, enc_input, dec_input, dec_pos):
    # Token embedding plus positional embedding for the decoder input.
    dec_output = self.dec_ebd(dec_input) + self.pos_ebd(dec_pos)
    # Self-attention mask: combine the padding mask with the subsequent
    # (look-ahead) mask so each position attends only to valid, earlier tokens.
    dec_slf_attn_mask = torch.gt(
        get_attn_padding_mask(dec_input, dec_input) + get_attn_subsequent_mask(dec_input), 0)
    # Encoder-decoder attention mask: hide padded encoder positions from the decoder.
    dec_enc_attn_pad_mask = get_attn_padding_mask(dec_input, enc_input)
    # Run each decoder layer against its corresponding encoder output.
    for layer, enc_output in zip(self.decodes, enc_outputs):
        dec_output = layer(dec_output, enc_output,
                           dec_slf_attn_mask, dec_enc_attn_pad_mask)
    return dec_output
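
The two mask helpers called above are not shown in this snippet. A minimal sketch of what they typically look like, assuming padding index 0 and bool masks of shape (batch, len_q, len_k) (the exact signatures here are an assumption, not the author's code):

import torch

def get_attn_padding_mask(seq_q, seq_k, pad_idx=0):
    # True where the key token is padding; shape (batch, len_q, len_k).
    # pad_idx=0 is an assumed convention, not confirmed by the source.
    len_q = seq_q.size(1)
    pad_mask = seq_k.eq(pad_idx).unsqueeze(1)   # (batch, 1, len_k)
    return pad_mask.expand(-1, len_q, -1)       # (batch, len_q, len_k)

def get_attn_subsequent_mask(seq):
    # True strictly above the diagonal, blocking attention to future
    # positions; shape (batch, len, len).
    batch, length = seq.size()
    mask = torch.triu(
        torch.ones(length, length, dtype=torch.bool, device=seq.device),
        diagonal=1)
    return mask.unsqueeze(0).expand(batch, -1, -1)

With bool masks, the `+` in the forward behaves as a logical OR, and `torch.gt(..., 0)` reduces the result back to a bool mask; the explicit `gt` suggests the original used byte (0/1) masks, where the sum can exceed 1.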