def test_elbo_with_transformed_distribution(self):
    """Check that SVI fits the guide when the likelihood is a TransformedDistribution.

    Model: ``mu_latent`` is sampled from ``dist.normal`` with mean ``self.mu0``
    and scale ``self.tau0 ** -0.5``. Each of the two observations
    ``self.data[0]`` and ``self.data[1]`` is scored under a ``dist.normal``
    base distribution with parameters ``ng_zeros(1)`` / ``ng_ones(1)``.
    That base is pushed through ``AffineExp(self.tau ** -0.5, mu_latent)``.
    This is presumably ``exp(scale * z + mu_latent)``, i.e. a log-normal
    likelihood (the debug banner says LOGNORMAL-NORMAL); TODO confirm
    against ``AffineExp``.

    Guide: a ``dist.normal`` over ``mu_latent``. Its mean and precision are
    learned in log space through the params ``mu_q_log`` and ``tau_q_log``,
    which start at a fixed offset from the targets ``self.log_mu_n`` and
    ``self.log_tau_n``.

    After 7000 Adam steps (``loss="ELBO"``, ``trace_graph=True``), the
    absolute error of each param against its target is checked with
    ``self.assertEqual(0.0, error, prec=0.05)``.
    """
    if self.verbose:
        print(" - - - - - DO LOGNORMAL-NORMAL ELBO TEST [uses TransformedDistribution] - - - - - ")
    pyro.clear_param_store()

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                self.mu0, torch.pow(self.tau0, -0.5))
        bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
        x_dist = TransformedDistribution(dist.normal, bijector)
        # The base distribution gets zero/one parameters; all location and
        # scale information for x lives in the bijector.
        pyro.observe("obs0", x_dist, self.data[0], ng_zeros(1), ng_ones(1))
        pyro.observe("obs1", x_dist, self.data[1], ng_zeros(1), ng_ones(1))
        return mu_latent

    def guide():
        # Initialise away from the targets (+0.17 / -0.143) so that passing
        # the test requires the optimiser to actually move the params.
        mu_q_log = pyro.param(
            "mu_q_log",
            Variable(self.log_mu_n.data + 0.17, requires_grad=True))
        tau_q_log = pyro.param(
            "tau_q_log",
            Variable(self.log_tau_n.data - 0.143, requires_grad=True))
        mu_q, tau_q = torch.exp(mu_q_log), torch.exp(tau_q_log)
        pyro.sample("mu_latent", dist.normal, mu_q, torch.pow(tau_q, -0.5))

    adam = optim.Adam({"lr": 0.001, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

    n_steps = 7000
    log_every = 500
    for k in range(n_steps):
        svi.step()
        # Only measure the errors when they will actually be printed; the
        # assertions below take their own measurement after training.
        if self.verbose and k % log_every == 0:
            mu_error = param_abs_error("mu_q_log", self.log_mu_n)
            tau_error = param_abs_error("tau_q_log", self.log_tau_n)
            print("mu_error, tau_error = %.4f, %.4f" % (mu_error, tau_error))

    # Final errors, measured once after the last optimisation step.
    mu_error = param_abs_error("mu_q_log", self.log_mu_n)
    tau_error = param_abs_error("tau_q_log", self.log_tau_n)
    self.assertEqual(0.0, mu_error, prec=0.05)
    self.assertEqual(0.0, tau_error, prec=0.05)
# Comment list
# Table of contents