def test_iter_discrete_traces_scalar(graph_type):
    """Check the traces and scales enumerated for a two-site scalar model.

    The model samples a Bernoulli site ``"x"`` and a four-entry Categorical
    site ``"y"``.  The test asserts that ``iter_discrete_traces`` yields
    ``2 * len(ps)`` ``(scale, trace)`` pairs, and that each scale equals
    ``[1 - p, p][x] * ps[y]`` for the ``x`` and ``y`` values recorded in
    that trace.

    ``graph_type`` is forwarded unchanged to ``iter_discrete_traces``.
    """
    pyro.clear_param_store()

    def model():
        bern_prob = pyro.param("p", Variable(torch.Tensor([0.05])))
        cat_probs = pyro.param(
            "ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4]))
        )
        x_sample = pyro.sample("x", dist.Bernoulli(bern_prob))
        y_sample = pyro.sample("y", dist.Categorical(cat_probs, one_hot=False))
        return {"x": x_sample, "y": y_sample}

    def site_index(tr, site):
        # Flatten the recorded value and use its first entry as an index.
        flat = tr.nodes[site]["value"].data.long().view(-1)
        return flat[0]

    enumerated = list(iter_discrete_traces(graph_type, model))

    # Read the parameters back from the store as raw tensors.
    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(enumerated) == 2 * len(ps)

    # Lookup table indexed by the sampled x value (0 or 1).
    x_probs = [1 - p[0], p[0]]

    for scale, trace in enumerated:
        x_idx = site_index(trace, "x")
        y_idx = site_index(trace, "y")
        joint = x_probs[x_idx] * ps[y_idx]
        expected_scale = Variable(torch.Tensor([joint]))
        assert_equal(scale, expected_scale)
评论列表
文章目录