test_cuda.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_sample(dist):
    for idx in range(len(dist.dist_params)):

        # Compute CPU value.
        with tensors_default_to("cpu"):
            params = dist.get_dist_params(idx)
        try:
            cpu_value = dist.pyro_dist.sample(**params)
        except ValueError as e:
            pytest.xfail('CPU version fails: {}'.format(e))
        assert not cpu_value.is_cuda

        # Compute GPU value.
        with tensors_default_to("cuda"):
            params = dist.get_dist_params(idx)
        cuda_value = dist.pyro_dist.sample(**params)
        assert cuda_value.is_cuda

        assert_equal(cpu_value.size(), cuda_value.size())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号