import torch
import torch.nn as nn
import torch.nn.functional as F

def new_att_module(self):
    class NewAttModule(nn.Module):
        def __init__(self):
            super(NewAttModule, self).__init__()

        def forward(self, linput, rinput):
            # TODO: figure out why the original Torch code applied Padding(0, 0) here;
            # for now linput is used as-is.
            self.lPad = linput
            # Similarity matrix between left and right positions: (len_l, len_r)
            self.M_r = torch.mm(self.lPad, rinput.t())
            # Attention weights over the left positions for each right position: (len_r, len_l)
            self.alpha = F.softmax(self.M_r.transpose(0, 1), dim=1)
            # Attention-weighted summary of the left input: (len_r, dim)
            self.Yl = torch.mm(self.alpha, self.lPad)
            return self.Yl

    att_module = NewAttModule()
    # Copy gradient buffers from the master attention module, if one is set.
    if getattr(self, "att_module_master", None):
        for tar_param, src_param in zip(att_module.parameters(), self.att_module_master.parameters()):
            tar_param.grad.data = src_param.grad.data.clone()
    return att_module
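For reference, a minimal usage sketch follows. The SimilarityModel host class, the tensor sizes, and the way the factory is attached are assumptions for illustration; only the attention arithmetic comes from the snippet above. With linput of shape (len_l, dim) and rinput of shape (len_r, dim), the returned Yl has shape (len_r, dim).

# Usage sketch. SimilarityModel is a hypothetical host class; the original
# snippet only shows the new_att_module factory method itself.
import torch
import torch.nn as nn

class SimilarityModel(nn.Module):
    def __init__(self):
        super(SimilarityModel, self).__init__()
        # Point this at an existing attention module to enable the
        # gradient-copying branch inside new_att_module.
        self.att_module_master = None

# Attach the factory defined above as a method of the host class.
SimilarityModel.new_att_module = new_att_module

model = SimilarityModel()
att = model.new_att_module()

linput = torch.randn(5, 150)   # left sequence: 5 positions, 150-dim vectors
rinput = torch.randn(7, 150)   # right sequence: 7 positions, 150-dim vectors
Yl = att(linput, rinput)
print(Yl.shape)                # torch.Size([7, 150])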