def m_mpnn(self, h_v, h_w, e_vw, opt=None):
    """Compute per-edge messages ``A(e_vw) @ h_w`` using an edge-conditioned matrix.

    The name suggests this is the edge-network message function of
    Gilmer et al. (2017), "Neural Message Passing for Quantum Chemistry".
    TODO: confirm.

    Args:
        h_v: Tensor whose ``size(1)`` sets how many times each row of ``h_w``
            is repeated. Presumably ``(batch, nodes, features)``, so this is
            the node count per graph; confirm against the caller. Nothing
            else in ``h_v`` is read.
        h_w: 2-D tensor ``(R, self.args['in'])`` of sender states.
        e_vw: Edge features passed to ``self.learn_modules[0]``. That module's
            output must hold ``self.args['out'] * self.args['in']`` values per
            edge. There must be exactly ``h_w.size(0) * h_v.size(1)`` edges,
            because ``torch.bmm`` requires matching batch sizes.
        opt: Unused; accepted only for signature compatibility.

    Returns:
        Tensor of shape ``(h_w.size(0) * h_v.size(1), self.args['out'])``.
        Message ``k`` is ``A_k @ h_w[k // h_v.size(1)]``, where ``A_k`` is
        the ``(out, in)`` matrix produced for edge ``k``.
    """
    # One (out x in) matrix per edge, generated from that edge's features.
    edge_output = self.learn_modules[0](e_vw)
    edge_output = edge_output.view(-1, self.args['out'], self.args['in'])

    # Repeat each sender row n_rep times: (R, in) -> (R, 1, in) -> (R, n_rep, in).
    # The singleton axis must go at dim 1. `h_w[..., None]` would put it last,
    # giving (R, in, 1). That only expands when in == n_rep, and then it yields
    # rows of one repeated scalar.
    n_rep = h_v.size(1)
    h_w_rows = h_w.unsqueeze(1).expand(h_w.size(0), n_rep, h_w.size(1)).contiguous()
    h_w_rows = h_w_rows.view(-1, self.args['in'])

    # Batched matrix-vector product: (E, out, in) @ (E, in, 1) -> (E, out, 1).
    h_multiply = torch.bmm(edge_output, h_w_rows.unsqueeze(2))

    # Drop only the trailing vector axis. A bare squeeze() would also collapse
    # the edge or output axis whenever either of them has size 1.
    return h_multiply.squeeze(2)
评论列表
文章目录