Transformer.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:OpenNMT-py 作者: OpenNMT 项目源码 文件源码
def forward(self, input, context, src_pad_mask, tgt_pad_mask):
        # Args Checks
        input_batch, input_len, _ = input.size()
        contxt_batch, contxt_len, _ = context.size()
        aeq(input_batch, contxt_batch)

        src_batch, t_len, s_len = src_pad_mask.size()
        tgt_batch, t_len_, t_len__ = tgt_pad_mask.size()
        aeq(input_batch, contxt_batch, src_batch, tgt_batch)
        aeq(t_len, t_len_, t_len__, input_len)
        aeq(s_len, contxt_len)
        # END Args Checks

        dec_mask = torch.gt(tgt_pad_mask + self.mask[:, :tgt_pad_mask.size(1),
                            :tgt_pad_mask.size(1)]
                            .expand_as(tgt_pad_mask), 0)
        input_norm = self.layer_norm_1(input)
        query, attn = self.self_attn(input_norm, input_norm, input_norm,
                                     mask=dec_mask)
        query_norm = self.layer_norm_2(query+input)
        mid, attn = self.context_attn(context, context, query_norm,
                                      mask=src_pad_mask)
        output = self.feed_forward(mid+query+input)

        # CHECKS
        output_batch, output_len, _ = output.size()
        aeq(input_len, output_len)
        aeq(contxt_batch, output_batch)

        n_batch_, t_len_, s_len_ = attn.size()
        aeq(input_batch, n_batch_)
        aeq(contxt_len, s_len_)
        aeq(input_len, t_len_)
        # END CHECKS

        return output, attn
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号