test_tracegraph_elbo.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def do_elbo_test(self, reparameterized, n_steps, beta1, lr):
        """Run SVI on the lognormal-normal model and assert the guide converges.

        A ``LogNormalNormalGuide`` is constructed from ``self.log_mu_n`` and
        ``self.log_tau_n`` shifted by fixed offsets (+0.17 and -0.143), then
        optimised with Adam for ``n_steps`` steps using the trace-graph ELBO.
        After training, the absolute errors of the ``mu_q_log`` and
        ``tau_q_log`` parameters relative to ``self.log_mu_n`` and
        ``self.log_tau_n`` must both be within 0.05 of zero.

        Args:
            reparameterized: Forwarded as the ``reparameterized`` flag of the
                guide's ``dist.Normal``.
            n_steps: Number of ``svi.step()`` calls to perform.
            beta1: First Adam beta coefficient; the second is fixed at 0.999.
            lr: Adam learning rate.
        """
        if self.verbose:
            print(" - - - - - DO LOGNORMAL-NORMAL ELBO TEST [repa = %s] - - - - - " % reparameterized)
        pyro.clear_param_store()
        pt_guide = LogNormalNormalGuide(self.log_mu_n.data + 0.17,
                                        self.log_tau_n.data - 0.143)

        def model():
            # Normal prior on the latent location; tau0 is converted to a
            # scale via tau0 ** -0.5 before being handed to dist.normal.
            mu_latent = pyro.sample("mu_latent", dist.normal,
                                    self.mu0, torch.pow(self.tau0, -0.5))
            sigma = torch.pow(self.tau, -0.5)
            # Two lognormal observations sharing the same latent and scale.
            pyro.observe("obs0", dist.lognormal, self.data[0], mu_latent, sigma)
            pyro.observe("obs1", dist.lognormal, self.data[1], mu_latent, sigma)
            return mu_latent

        def guide():
            # Register the guide module so its parameters are visible to the
            # optimiser under the "mymodule" prefix.
            pyro.module("mymodule", pt_guide)
            # The module stores logs of mu and tau; exponentiate to recover them.
            mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
            sigma = torch.pow(tau_q, -0.5)
            pyro.sample("mu_latent",
                        dist.Normal(mu_q, sigma, reparameterized=reparameterized),
                        baseline=dict(use_decaying_avg_baseline=True))

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

        def current_errors():
            # NOTE(review): assumes param_abs_error only reads the parameter
            # store and has no side effects -- confirm against its definition.
            return (param_abs_error("mymodule$$$mu_q_log", self.log_mu_n),
                    param_abs_error("mymodule$$$tau_q_log", self.log_tau_n))

        for k in range(n_steps):
            svi.step()
            # Only pay for the error lookup when a progress line is printed.
            if self.verbose and k % 500 == 0:
                print("mu_error, tau_error = %.4f, %.4f" % current_errors())

        # Evaluate after the loop so the values exist even when n_steps == 0
        # (previously that case raised NameError at the assertions below).
        mu_error, tau_error = current_errors()
        self.assertEqual(0.0, mu_error, prec=0.05)
        self.assertEqual(0.0, tau_error, prec=0.05)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号