MessageFunction.py 文件源码

python
阅读 51 收藏 0 点赞 0 评论 0

项目:nmp_qc 作者: priba 项目源码 文件源码
def m_mpnn(self, h_v, h_w, e_vw, opt={}):
        # Matrices for each edge
        edge_output = self.learn_modules[0](e_vw)
        edge_output = edge_output.view(-1, self.args['out'], self.args['in'])

        h_w_rows = h_w[..., None].expand(h_w.size(0), h_v.size(1), h_w.size(1)).contiguous()

        h_w_rows = h_w_rows.view(-1, self.args['in'])

        h_multiply = torch.bmm(edge_output, torch.unsqueeze(h_w_rows,2))

        m_new = torch.squeeze(h_multiply)

        return m_new
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号