def test_batch_dim(batch_dim):
    """Check that ``pyro.map_data`` subsamples along ``batch_dim``.

    A random (4, 5, 7) tensor is fed to ``pyro.map_data`` with
    ``batch_size=1`` along ``batch_dim``. Inside the per-batch model, each
    minibatch is reshaped to 2-D (keeping the trailing axis of size 7).
    It is then used as the mean of a unit-scale ``dist.normal`` sample
    site named "xs".

    The traced "xs" value is expected to have two dimensions:

    * ``data.size(1 - batch_dim)`` rows;
    * ``data.size(2)`` columns.

    NOTE(review): the ``1 - batch_dim`` indexing implies ``batch_dim`` is
    0 or 1. It is presumably supplied by a pytest parametrization outside
    this view; TODO confirm.
    """
    data = Variable(torch.randn(4, 5, 7))

    def local_model(ixs, _xs):
        # Flatten the minibatch to 2-D, preserving the last (size-7) axis.
        mean = _xs.view(-1, _xs.size(2))
        scale = Variable(torch.ones(mean.size()))
        return pyro.sample("xs", dist.normal, mean, scale)

    def model():
        return pyro.map_data(
            "md",
            data,
            local_model,
            batch_size=1,
            batch_dim=batch_dim,
        )

    trace = poutine.trace(model).get_trace()
    sampled = trace.nodes["xs"]["value"]

    # The asserts expect one row per entry of the leading dim that was
    # *not* batched, and columns matching the trailing data axis.
    assert sampled.size(0) == data.size(1 - batch_dim)
    assert sampled.size(1) == data.size(2)
评论列表
文章目录