model.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def forward(self, enc_outputs, enc_input, dec_input, dec_pos):
        dec_output = self.dec_ebd(dec_input) + self.pos_ebd(dec_pos)

        dec_slf_attn_mask = torch.gt(
            get_attn_padding_mask(dec_input, dec_input) + get_attn_subsequent_mask(dec_input), 0)
        dec_enc_attn_pad_mask = get_attn_padding_mask(dec_input, enc_input)

        for layer, enc_output in zip(self.decodes, enc_outputs):
            dec_output = layer(dec_output, enc_output,
                dec_slf_attn_mask, dec_enc_attn_pad_mask)

        return dec_output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号