def do_elbo_test(self, reparameterized, n_steps):
    """Fit a Normal guide to the Normal-Normal model with SVI and check convergence.

    The model draws ``mu_latent`` from a Normal prior with mean ``self.mu0`` and
    std ``self.lam0 ** -0.5``. It then observes every element of ``self.data``
    under a Normal likelihood centred at ``mu_latent`` with std
    ``self.lam ** -0.5``. The guide is a Normal with learnable mean ``mu_q`` and
    log-std ``log_sig_q``. Both start at a fixed offset from
    ``self.analytic_mu_n`` / ``self.analytic_log_sig_n``, which are presumably
    the closed-form posterior parameters prepared by the test fixture.

    Args:
        reparameterized: forwarded to both ``dist.Normal`` sample sites.
        n_steps: number of ``svi.step()`` iterations to run.

    Raises:
        AssertionError: if, after ``n_steps`` steps, the error reported by
            ``param_mse`` for either variational parameter is not within
            ``prec=0.03`` of zero.
    """
    if self.verbose:
        print(" - - - - - DO NORMALNORMAL ELBO TEST [reparameterized = %s] - - - - - " % reparameterized)
    pyro.clear_param_store()

    def model():
        # lam0 / lam are used as precisions: std = precision ** -0.5.
        mu_latent = pyro.sample(
            "mu_latent",
            dist.Normal(self.mu0, torch.pow(self.lam0, -0.5), reparameterized=reparameterized))
        for i, x in enumerate(self.data):
            pyro.observe("obs_%d" % i, dist.normal, x, mu_latent,
                         torch.pow(self.lam, -0.5))
        return mu_latent

    def guide():
        # Start away from the analytic targets (offsets +0.334 / -0.29) so the
        # optimizer has work to do. Scalar offsets broadcast over the tensor,
        # so this no longer hard-codes a 2-dimensional latent.
        mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.334,
                                           requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", Variable(
            self.analytic_log_sig_n.data - 0.29,
            requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        # NOTE(review): the decaying-average baseline presumably only affects the
        # score-function gradient used when reparameterized is False — confirm.
        mu_latent = pyro.sample("mu_latent",
                                dist.Normal(mu_q, sig_q, reparameterized=reparameterized),
                                baseline=dict(use_decaying_avg_baseline=True))
        return mu_latent

    adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

    for k in range(n_steps):
        svi.step()
        # Only evaluate the errors on steps where they are actually printed.
        if self.verbose and k % 250 == 0:
            mu_error = param_mse("mu_q", self.analytic_mu_n)
            log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
            print("mu error, log(sigma) error: %.4f, %.4f" % (mu_error, log_sig_error))

    # Final errors, measured once after training has finished.
    mu_error = param_mse("mu_q", self.analytic_mu_n)
    log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
    self.assertEqual(0.0, mu_error, prec=0.03)
    self.assertEqual(0.0, log_sig_error, prec=0.03)
评论列表
文章目录