transform_filter.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:GrouPy 作者: tscohen 项目源码 文件源码
def forward_gpu(self, inputs):

        w, = inputs
        xp = cuda.get_array_module(w)
        och, ich, _, ny, nx = w.shape

        nto, nti = self.T.shape[:2]
        rotated_w = xp.empty((och, nto, ich, nti, ny, nx), dtype=w.dtype)

        index_group_func_kernel(
            input=w,
            T=self.T,
            U=self.U,
            V=self.V,
            output=rotated_w
        )

        return rotated_w,
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号