test_enum.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目: pyro 作者: uber 项目源码 文件源码
def test_iter_discrete_traces_scalar(graph_type):
    """Check ``iter_discrete_traces`` on a model with two scalar discrete sites.

    The model registers two params (``"p"`` and ``"ps"``). It then samples a
    Bernoulli site ``"x"`` and a Categorical site ``"y"`` (``one_hot=False``).

    The test asserts two things:

    * exactly ``2 * len(ps)`` ``(scale, trace)`` pairs are produced;
    * each pair's ``scale`` equals the Bernoulli weight at the recorded ``x``
      value times ``ps`` at the recorded ``y`` value.
    """
    pyro.clear_param_store()

    def model():
        bern_prob = pyro.param("p", Variable(torch.Tensor([0.05])))
        cat_probs = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
        # Keyword arguments are evaluated left to right, so "x" is sampled before "y".
        return dict(
            x=pyro.sample("x", dist.Bernoulli(bern_prob)),
            y=pyro.sample("y", dist.Categorical(cat_probs, one_hot=False)),
        )

    def _site_index(tr, site_name):
        # Flatten the recorded sample value and take its first entry as an integer index.
        return tr.nodes[site_name]["value"].data.long().view(-1)[0]

    enumerated = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # One trace per joint assignment: 2 Bernoulli outcomes times len(ps) categories.
    assert len(enumerated) == 2 * len(ps)

    # Index 0 -> weight 1 - p[0]; index 1 -> weight p[0]. Invariant across traces.
    bernoulli_weights = [1 - p[0], p[0]]
    for scale, tr in enumerated:
        x_idx = _site_index(tr, "x")
        y_idx = _site_index(tr, "y")
        want = Variable(torch.Tensor([bernoulli_weights[x_idx] * ps[y_idx]]))
        assert_equal(scale, want)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号