compAggWikiqa.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:SeqMatchSeq 作者: pcgreat 项目源码 文件源码
def new_att_module(self):

        class NewAttModule(nn.Module):
            def __init__(self):
                super(NewAttModule, self).__init__()

            def forward(self, linput, rinput):
                self.lPad = linput.view(-1, linput.size(0), linput.size(1))

                self.lPad = linput  # self.lPad = Padding(0, 0)(linput) TODO: figureout why padding?
                self.M_r = torch.mm(self.lPad, rinput.t())
                self.alpha = F.softmax(self.M_r.transpose(0, 1))
                self.Yl = torch.mm(self.alpha, self.lPad)
                return self.Yl

        att_module = NewAttModule()
        if getattr(self, "att_module_master", None):
            for (tar_param, src_param) in zip(att_module.parameters(), self.att_module_master.parameters()):
                tar_param.grad.data = src_param.grad.data.clone()
        return att_module
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号