def test_nested_map_data():
    """Check that nested ``pyro.map_data`` calls subsample at both levels.

    The outer map runs over ``means`` with ``batch_size=mean_batch_size``.
    For each selected mean, an inner map runs over ``stds`` with
    ``batch_size=std_batch_size`` and draws one ``dist.normal`` sample per
    (mean, std) pair, at a site named ``"x_{i}{j}"``.

    The test asserts two things:
    * the returned nested list has ``mean_batch_size`` outer entries, and the
      first has ``std_batch_size`` inner entries;
    * every ``x_*`` sample site in a trace of the model carries a scale equal
      to the product of the two (dataset size / batch size) ratios.
    """
    means = [Variable(torch.randn(2)) for _ in range(8)]
    mean_batch_size = 2
    stds = [Variable(torch.abs(torch.randn(2))) for _ in range(6)]
    std_batch_size = 3

    def model(means, stds):
        # The inner lambda closes over the outer lambda's ``i`` and ``x``.
        # Each outer call gets its own scope, so there is no late-binding issue.
        return pyro.map_data(
            "a", means,
            lambda i, x: pyro.map_data(
                "a_{}".format(i), stds,
                lambda j, y: pyro.sample("x_{}{}".format(i, j),
                                         dist.normal, x, y),
                batch_size=std_batch_size),
            batch_size=mean_batch_size)

    xs = model(means, stds)
    assert len(xs) == mean_batch_size
    assert len(xs[0]) == std_batch_size

    # Expected scale is the product of the per-level ratios:
    # (8 / 2) * (6 / 3) == 4.0 * 2.0 == 8.0. float() keeps true division
    # on Python 2 as well.
    expected_scale = ((float(len(means)) / mean_batch_size) *
                      (float(len(stds)) / std_batch_size))

    tr = poutine.trace(model).get_trace(means, stds)
    checked = 0
    for name, node in tr.nodes.items():
        if node["type"] == "sample" and name.startswith("x_"):
            assert node["scale"] == expected_scale
            checked += 1
    # Without this, the loop above would pass silently if the trace
    # recorded no x_* sample sites at all.
    assert checked > 0
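# Scraped web-page footer residue (not part of the test):
# "评论列表" (Comment list) / "文章目录" (Table of contents)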