test_tracegraph_elbo.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_elbo_nonreparameterized(self):
        """Fit a Beta guide to a Bernoulli-Beta model with SVI and a trace-graph ELBO.

        The guide samples ``p_latent`` from a Beta with a decaying-average
        baseline, and its log-parameters start offset from
        ``self.log_alpha_n`` / ``self.log_beta_n``. After 3000 Adam steps, the
        absolute error between each learned log-parameter and its target must
        fall within the tolerances checked at the end.
        NOTE(review): the targets are presumably the analytic posterior
        log-parameters set up by the test fixture; confirm.
        """
        if self.verbose:
            print(" - - - - - DO BERNOULLI-BETA ELBO TEST - - - - - ")
        pyro.clear_param_store()

        def bernoulli_beta_model():
            latent = pyro.sample("p_latent", dist.beta, self.alpha0, self.beta0)
            for idx, obs in enumerate(self.data):
                # sqrt(latent ** 2) equals latent for a Beta sample;
                # NOTE(review): presumably here to route the probability through
                # extra tensor ops. Confirm the intent.
                pyro.observe("obs_{}".format(idx), dist.bernoulli, obs,
                             torch.pow(torch.pow(latent, 2.0), 0.5))
            return latent

        def beta_guide():
            # Start the variational log-parameters deliberately away from the
            # targets the final assertions compare against.
            log_alpha = pyro.param(
                "alpha_q_log",
                Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
            log_beta = pyro.param(
                "beta_q_log",
                Variable(self.log_beta_n.data - 0.143, requires_grad=True))
            latent = pyro.sample(
                "p_latent", dist.beta,
                torch.exp(log_alpha), torch.exp(log_beta),
                baseline={"use_decaying_avg_baseline": True})
            return latent

        optimizer = optim.Adam({"lr": .0007, "betas": (0.96, 0.999)})
        inference = SVI(bernoulli_beta_model, beta_guide, optimizer,
                        loss="ELBO", trace_graph=True)

        for step in range(3000):
            inference.step()

            alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
            beta_error = param_abs_error("beta_q_log", self.log_beta_n)

            # Periodic progress report, only when running verbosely.
            if self.verbose and step % 500 == 0:
                print("alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error))

        # The errors from the final optimisation step must be near zero.
        self.assertEqual(0.0, alpha_error, prec=0.03)
        self.assertEqual(0.0, beta_error, prec=0.04)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号