test_cuda.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
    """Build a test method that checks a tensor method behaves the same on CPU and GPU.

    The returned function takes ``self`` and calls ``self.assertEqual`` with a
    precision argument, so it is meant to be attached to a test-case class.

    Args:
        tensor_constructor: callable invoked as ``tensor_constructor(t)`` to
            build the CPU tensor the method is called on.
        arg_constructor: callable invoked as ``arg_constructor(t)`` to build
            the positional arguments passed to the method; it must return an
            iterable (it is iterated to build the GPU copies).
        fn: name of the method looked up with ``getattr`` on both tensors.
        t: value forwarded to both constructors (presumably a tensor type —
            confirm against callers).
        precision: tolerance forwarded to every ``self.assertEqual`` call.
        force_gpu_half: when true, ``torch.FloatTensor`` and
            ``torch.DoubleTensor`` are mapped to ``torch.cuda.HalfTensor``
            via the type map handed to ``to_gpu``.

    Returns:
        A function ``tmp(self)`` that runs the method on CPU and GPU copies and
        asserts that the receivers, the arguments, and the results all match.
        It raises ``unittest.SkipTest`` when the GPU call fails with a message
        indicating an unimplemented data type or a missing attribute.
    """
    def tmp(self):
        cpu_tensor = tensor_constructor(t)
        # Optional CPU-type -> GPU-type overrides passed to to_gpu.
        type_map = {
            'torch.FloatTensor': 'torch.cuda.HalfTensor',
            'torch.DoubleTensor': 'torch.cuda.HalfTensor',
        } if force_gpu_half else {}
        gpu_tensor = to_gpu(cpu_tensor, type_map)
        cpu_args = arg_constructor(t)
        gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
        cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
        try:
            gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
        except RuntimeError as e:
            # str(e) rather than e.args[0]: an exception raised without
            # arguments (or with a non-string first argument) must be
            # re-raised as-is, not masked by an IndexError/TypeError here.
            reason = str(e)
            if 'only supports floating-point types' in reason or 'unimplemented data type' in reason:
                raise unittest.SkipTest('unimplemented data type')
            raise
        except AttributeError as e:
            reason = str(e)
            # The GPU tensor type does not provide this method at all.
            if 'object has no attribute' in reason:
                raise unittest.SkipTest('unimplemented data type')
            raise
        # If one changes, another should change as well (covers in-place
        # methods that mutate the receiver or the arguments).
        self.assertEqual(cpu_tensor, gpu_tensor, precision)
        self.assertEqual(cpu_args, gpu_args, precision)
        # Compare results
        self.assertEqual(cpu_result, gpu_result, precision)
    return tmp
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号