graph_convolution.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:chainer-graph-cnn 作者: pfnet-research 项目源码 文件源码
def forward_gpu(self, inputs):
        x, W = inputs[:2]
        n_batch, c_in, N = x.shape
        b = inputs[2] if len(inputs) == 3 else None
        xp = cuda.get_array_module(x)
        with cuda.get_device(x.data):
            K = self.K
            LmI_data, LmI_indices, LmI_indptr = self.LmI_tuple

            if x.dtype != LmI_data.dtype:
                LmI_data = LmI_data.astype(x.dtype)

            C = xp.empty((K, N, c_in, n_batch), dtype=x.dtype)
            chebyshev_matvec_gpu(C, x, K, n_batch,
                                 LmI_data, LmI_indices, LmI_indptr)

            C = C.transpose((3, 2, 0, 1))
            self.C = C
            y = xp.tensordot(C, W, ((1, 2), (1, 2)))

            if b is not None:
                y += b

            return xp.rollaxis(y, 2, 1),  # y.shape = (n_batch, c_out, N)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号