def test_random_module(nn_module):
pyro.clear_param_store()
nn_module = nn_module()
p = Variable(torch.ones(2, 2))
prior = dist.Bernoulli(p)
lifted_mod = pyro.random_module("module", nn_module, prior)
nn_module = lifted_mod()
for name, parameter in nn_module.named_parameters():
assert torch.equal(torch.ones(2, 2), parameter.data)
评论列表
文章目录