def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):
# Word embedding look up
dec_input = self.tgt_word_emb(tgt_seq)
# Position Encoding addition
dec_input += self.position_enc(tgt_pos)
# Decode
dec_slf_attn_pad_mask = get_attn_padding_mask(tgt_seq, tgt_seq)
dec_slf_attn_sub_mask = get_attn_subsequent_mask(tgt_seq)
dec_slf_attn_mask = torch.gt(dec_slf_attn_pad_mask + dec_slf_attn_sub_mask, 0)
dec_enc_attn_pad_mask = get_attn_padding_mask(tgt_seq, src_seq)
if return_attns:
dec_slf_attns, dec_enc_attns = [], []
dec_output = dec_input
for dec_layer in self.layer_stack:
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output, enc_output,
slf_attn_mask=dec_slf_attn_mask,
dec_enc_attn_mask=dec_enc_attn_pad_mask)
if return_attns:
dec_slf_attns += [dec_slf_attn]
dec_enc_attns += [dec_enc_attn]
if return_attns:
return dec_output, dec_slf_attns, dec_enc_attns
else:
return dec_output,
Models.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录