def do_elbo_test(self, reparameterized, n_steps):
    """Run SVI on a normal/normal model and check the guide against analytic values.

    The model samples ``mu_latent`` from a normal with mean ``self.mu0`` and
    scale ``self.lam0 ** -0.5``. It then observes every element of
    ``self.data`` under a normal with mean ``mu_latent`` and scale
    ``self.lam ** -0.5``. The guide is a Normal over ``mu_latent`` with
    learnable ``mu_q`` and ``log_sig_q`` parameters. These start at a fixed
    offset from ``self.analytic_mu_n`` / ``self.analytic_log_sig_n``
    (presumably the closed-form posterior computed in ``setUp`` -- confirm).

    Args:
        reparameterized: forwarded as the ``reparameterized`` flag of the
            guide's ``dist.Normal`` for ``mu_latent``.
        n_steps: number of ``svi.step()`` calls to run before checking.

    Raises:
        AssertionError: if the error reported by ``param_mse`` for either
            ``mu_q`` or ``log_sig_q`` is not within ``prec=0.05`` of 0.0.
    """
    pyro.clear_param_store()

    def model():
        # lam0 / lam act as precisions: the scale passed is precision ** -0.5.
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                self.mu0, torch.pow(self.lam0, -0.5))
        pyro.map_data("aaa", self.data,
                      lambda i, x: pyro.observe("obs_%d" % i, dist.normal,
                                                x, mu_latent,
                                                torch.pow(self.lam, -0.5)),
                      batch_size=self.batch_size)
        return mu_latent

    def guide():
        # Initialise at a fixed offset from the analytic values. Adding a
        # scalar (instead of `c * torch.ones(2)`) gives identical values in
        # the 2-d case but works for any shape / tensor type of the targets.
        mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.134,
                                           requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", Variable(
            self.analytic_log_sig_n.data - 0.14,
            requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent",
                    dist.Normal(mu_q, sig_q, reparameterized=reparameterized))
        # Same map_data name and batch size as the model; the guide has
        # nothing to sample for the observed sites, so the body returns None.
        pyro.map_data("aaa", self.data, lambda i, x: None,
                      batch_size=self.batch_size)

    adam = optim.Adam({"lr": .001})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

    for _ in range(n_steps):
        svi.step()

    # Measure once, after training: same final values as measuring on every
    # step, and the names stay bound even when n_steps == 0.
    mu_error = param_mse("mu_q", self.analytic_mu_n)
    log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
    self.assertEqual(0.0, mu_error, prec=0.05)
    self.assertEqual(0.0, log_sig_error, prec=0.05)
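# NOTE(review): two lines of web-page scraping residue stood here
# ("Comment list", "Table of contents"); they are not part of the code.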