import torch
from torch.autograd import Variable

import pyro
from pyro.infer import Trace_ELBO, TraceGraph_ELBO

# Test helpers (assert_equal, finite_difference, xfail_if_not_implemented)
# come from Pyro's test utilities, e.g. tests/common.py in the Pyro repo.
from tests.common import assert_equal, finite_difference, xfail_if_not_implemented


def test_gmm_elbo_gradient(model, guide, enum_discrete, trace_graph):
    # `model`, `guide`, `enum_discrete` and `trace_graph` are supplied by
    # pytest parametrization/fixtures.
    pyro.clear_param_store()
    num_particles = 4000
    data = Variable(torch.Tensor([-1, 1]))

    print("Computing gradients using surrogate loss")
    Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
    # Exact enumeration of the discrete latents needs only one particle;
    # otherwise average over many Monte Carlo particles.
    elbo = Elbo(enum_discrete=enum_discrete,
                num_particles=(1 if enum_discrete else num_particles))
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide, data)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {name: pyro.param(name).grad.clone() for name in params}

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide, data))

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.5)
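
Note that finite_difference is a test-suite helper, not part of Pyro's public API. Below is a minimal sketch of how such a central-difference gradient check can be written against the parameter store; the helper body and the delta step size are assumptions for illustration, not Pyro's actual implementation.

import itertools

def finite_difference(eval_loss, delta=0.1):
    """Approximate d(loss)/d(param) for every param via central differences.
    (Illustrative sketch; assumes the old Variable-based API used above.)"""
    params = pyro.get_param_store().get_all_param_names()
    grads = {name: Variable(torch.zeros(pyro.param(name).size()))
             for name in params}
    for name in params:
        value = pyro.param(name).data
        # Perturb each scalar entry of the parameter tensor in turn.
        for index in itertools.product(*map(range, value.size())):
            center = value[index]
            value[index] = center + delta
            pos = eval_loss()
            value[index] = center - delta
            neg = eval_loss()
            value[index] = center  # restore the original value
            grads[name][index] = (pos - neg) / (2 * delta)
    return grads

With num_particles=4000, each eval_loss() call is itself a noisy Monte Carlo estimate, which is why the final comparison uses the loose tolerance prec=0.5.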