def test_iter_discrete_traces_vector(graph_type):
    """Check ``iter_discrete_traces`` on a model with batched discrete sites.

    The model samples a Bernoulli site ``x`` and a Categorical site ``y``,
    each parameterized with a batch of 2 rows. The test expects
    ``2 * ps.size(-1)`` enumerated traces. For each trace, the returned
    ``scale`` must equal the joint probability of the enumerated values.
    Both sites are independent, so that is exp(log p(x) + log p(y)).

    Args:
        graph_type: forwarded unchanged to ``iter_discrete_traces``
            (presumably supplied by a pytest parametrization outside this
            view — confirm).
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        # Samples must keep the batch dimension: 2 rows, one value per row.
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # Two Bernoulli outcomes times the number of categories.
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        # The enumerated value is read from the first batch element only.
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # Independent sites: joint probability = product of marginals,
        # i.e. exp of the SUM of the log-probabilities (not their product).
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) +
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
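# NOTE(review): two stray lines of web-page residue were here
# ("评论列表" = "Comment list", "文章目录" = "Table of contents").
# As bare module-level names they would raise NameError on import.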