def u_mpnn(self, h_v, m_v, opt=None):
    """Update hidden states with one step of the learned recurrent module.

    Flattens the first two dimensions of ``h_v`` and ``m_v`` into a single
    row dimension. It then calls ``self.learn_modules[0]`` once, with the
    messages as a length-1 input sequence and the hidden states as the
    initial hidden state. Finally it reshapes the result back to
    ``h_v``'s shape.

    Args:
        h_v: Current hidden states, a 3-D tensor. Presumably laid out as
            (batch, nodes, hidden_size) -- NOTE(review): confirm with caller.
        m_v: Incoming messages, a 3-D tensor. Its first two dimensions must
            flatten to the same number of rows as ``h_v``.
        opt: Unused here; kept so existing callers that pass it still work.
            Defaults to ``None`` instead of a shared mutable ``{}``.

    Returns:
        A tensor with the same shape as ``h_v`` holding the updated states.
        This requires the module's output feature size to equal
        ``h_v.size(2)``.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed tensors; when a view is possible it returns one.
    h_in = h_v.reshape(-1, h_v.size(2))
    m_in = m_v.reshape(-1, m_v.size(2))
    # [None, ...] prepends a length-1 axis, so every row becomes one batch
    # element of a single-step call module(input, h_0).
    # NOTE(review): assumes learn_modules[0] follows torch.nn.GRU's call
    # convention with batch_first=False -- confirm where it is constructed.
    # Element [0] of the result is the per-step output. For a single-layer,
    # unidirectional GRU run on a length-1 sequence, it holds the same values
    # as h_n (element [1]), so either choice gives the same update.
    h_new = self.learn_modules[0](m_in[None, ...], h_in[None, ...])[0]
    # h_new has the same number of elements as h_v; restore h_v's layout.
    return h_new.reshape(h_v.size())
评论列表
文章目录