test_enum.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_iter_discrete_traces_nan(enum_discrete, trace_graph):
    """ELBO losses must not be NaN when Bernoulli probabilities hit 0 or 1.

    Both the model and the guide sample a single site ``"z"`` from a
    Bernoulli whose probability vector contains the degenerate values
    0.0 and 1.0 (plus 0.5). The test computes both ``loss`` and
    ``loss_and_grads`` for the selected ELBO flavour. Each result must be
    a Python ``float`` that is not NaN.

    Args:
        enum_discrete: forwarded unchanged to the ELBO constructor.
        trace_graph: if truthy, use ``TraceGraph_ELBO``; otherwise
            ``Trace_ELBO``.
    """
    pyro.clear_param_store()

    # Shared probability values; the 0.0 / 1.0 endpoints are the point of
    # this test.
    edge_probs = [0.0, 0.5, 1.0]

    def model():
        fixed_probs = Variable(torch.Tensor(edge_probs))
        pyro.sample("z", dist.Bernoulli(fixed_probs))

    def guide():
        # Learnable parameter "p", initialised to the same edge values.
        learned_probs = pyro.param(
            "p", Variable(torch.Tensor(edge_probs), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(learned_probs))

    def _check_loss(value):
        # The loss must come back as a plain float and must not be NaN.
        assert isinstance(value, float) and not math.isnan(value), value

    elbo_cls = TraceGraph_ELBO if trace_graph else Trace_ELBO
    elbo = elbo_cls(enum_discrete=enum_discrete)

    # NOTE(review): presumably this context manager converts a
    # NotImplementedError into an xfail; it is defined elsewhere, so confirm.
    with xfail_if_not_implemented():
        _check_loss(elbo.loss(model, guide))
        _check_loss(elbo.loss_and_grads(model, guide))


# A simple Gaussian mixture model, with no vectorization.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号