def test_sample(dist):
    """Sample ``dist.pyro_dist`` with CPU- and CUDA-built params and compare.

    For each index into ``dist.dist_params``:

    * build the params inside ``tensors_default_to("cpu")`` and draw a
      sample, which must not be a CUDA tensor;
    * build the params again inside ``tensors_default_to("cuda")`` and draw
      a sample, which must be a CUDA tensor;
    * require both samples to have equal ``size()``.

    A ``ValueError`` from the CPU sample calls ``pytest.xfail``. That raises
    immediately, so the remaining parameter sets are not checked.
    """
    num_param_sets = len(dist.dist_params)
    for param_index in range(num_param_sets):
        # CPU reference: a failure here is reported as an expected failure
        # instead of a test error.
        with tensors_default_to("cpu"):
            cpu_params = dist.get_dist_params(param_index)
        try:
            cpu_sample = dist.pyro_dist.sample(**cpu_params)
        except ValueError as err:
            pytest.xfail('CPU version fails: {}'.format(err))
        assert not cpu_sample.is_cuda

        # CUDA counterpart: rebuild the same parameter set under "cuda".
        with tensors_default_to("cuda"):
            cuda_params = dist.get_dist_params(param_index)
        cuda_sample = dist.pyro_dist.sample(**cuda_params)
        assert cuda_sample.is_cuda

        # Only the shapes are compared; the sampled values are random.
        assert_equal(cpu_sample.size(), cuda_sample.size())