test_enum.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目: pyro 作者: uber 项目源码 文件源码
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """Compare ELBO surrogate gradients for one Bernoulli latent to finite differences.

    The model draws site ``"z"`` from ``Bernoulli(0.25)``. The guide draws the
    same site from ``Bernoulli(p)``, where ``p`` is the learnable param ``"p"``
    initialised to 0.5. Gradients from ``loss_and_grads`` of the selected ELBO
    are checked against a finite-difference estimate of ``Trace_ELBO.loss``
    to a precision of 0.1.

    Args:
        enum_discrete: forwarded to the ELBO constructor. When truthy, only a
            single particle is used for the surrogate gradient; otherwise
            ``num_particles`` particles are drawn.
        trace_graph: when truthy use ``TraceGraph_ELBO``, else ``Trace_ELBO``.
    """
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        prior_prob = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(prior_prob))

    def guide():
        # A fresh initial value is built on every call; pyro.param decides
        # whether it is actually used.
        initial_prob = Variable(torch.Tensor([0.5]), requires_grad=True)
        guide_prob = pyro.param("p", initial_prob)
        pyro.sample("z", dist.Bernoulli(guide_prob))

    # --- gradients from the ELBO surrogate loss ---
    print("Computing gradients using surrogate loss")
    if trace_graph:
        elbo_cls = TraceGraph_ELBO
    else:
        elbo_cls = Trace_ELBO
    # One particle suffices when enum_discrete is set; otherwise use many.
    surrogate_particles = 1 if enum_discrete else num_particles
    surrogate_elbo = elbo_cls(enum_discrete=enum_discrete,
                              num_particles=surrogate_particles)
    with xfail_if_not_implemented():
        surrogate_elbo.loss_and_grads(model, guide)

    param_names = sorted(pyro.get_param_store().get_all_param_names())
    assert param_names, "no params found"
    actual_grads = {}
    for param_name in param_names:
        actual_grads[param_name] = pyro.param(param_name).grad.clone()

    # --- reference gradients via finite differences of Trace_ELBO.loss ---
    print("Computing gradients using finite difference")
    reference_elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(
        lambda: reference_elbo.loss(model, guide))

    for param_name in param_names:
        print("{} {}{}{}".format(param_name, "-" * 30,
                                 actual_grads[param_name].data,
                                 expected_grads[param_name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号