test_utils.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def test_gpu(self):
        compile_extension(
                name='gpulib',
                header=test_dir + '/ffi/src/cuda/cudalib.h',
                sources=[
                    test_dir + '/ffi/src/cuda/cudalib.c',
                ],
                with_cuda=True,
                verbose=False,
        )
        import gpulib
        tensor = torch.ones(2, 2).float()

        gpulib.good_func(tensor, 2, 1.5)
        self.assertEqual(tensor, torch.ones(2, 2) * 2 + 1.5)

        ctensor = tensor.cuda().fill_(1)
        gpulib.cuda_func(ctensor, 2, 1.5)
        self.assertEqual(ctensor, torch.ones(2, 2) * 2 + 1.5)

        self.assertRaises(TypeError,
                lambda: gpulib.cuda_func(tensor, 2, 1.5))
        self.assertRaises(TypeError,
                lambda: gpulib.cuda_func(ctensor.storage(), 2, 1.5))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号