test_inference.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_elbo_with_transformed_distribution(self):
        """Run SVI on a model with TransformedDistribution observations.

        The model samples ``mu_latent`` from a normal prior and observes
        ``self.data[0]`` and ``self.data[1]`` under a normal distribution
        pushed through an ``AffineExp`` bijector.  The guide is a normal
        over ``mu_latent`` parameterised by ``mu_q_log`` and ``tau_q_log``,
        initialised at fixed offsets from ``self.log_mu_n`` and
        ``self.log_tau_n``.  After optimisation both parameters are compared
        against those reference values via ``param_abs_error``.
        """
        pyro.clear_param_store()

        def model():
            zero = Variable(torch.zeros(1))
            one = Variable(torch.ones(1))
            # tau ** -0.5: presumably turns a precision into a standard
            # deviation -- confirm against the fixture that sets tau0/tau.
            prior_scale = torch.pow(self.tau0, -0.5)
            mu_latent = pyro.sample("mu_latent", dist.normal,
                                    self.mu0, prior_scale)
            obs_scale = torch.pow(self.tau, -0.5)
            bijector = AffineExp(obs_scale, mu_latent)
            x_dist = TransformedDistribution(dist.normal, bijector)
            # Both observation sites share the same transformed distribution.
            for idx, site in enumerate(("obs0", "obs1")):
                pyro.observe(site, x_dist, self.data[idx], zero, one)
            return mu_latent

        def guide():
            # Start each variational parameter slightly away from its
            # reference value (+0.17 and -0.143 respectively).
            mu_init = Variable(self.log_mu_n.data + 0.17,
                               requires_grad=True)
            mu_q_log = pyro.param("mu_q_log", mu_init)
            tau_init = Variable(self.log_tau_n.data - 0.143,
                                requires_grad=True)
            tau_q_log = pyro.param("tau_q_log", tau_init)
            # Parameters are stored in log space; exponentiate before use.
            mu_q = torch.exp(mu_q_log)
            tau_q = torch.exp(tau_q_log)
            guide_scale = torch.pow(tau_q, -0.5)
            pyro.sample("mu_latent", dist.normal, mu_q, guide_scale)

        adam_config = {"lr": 0.0005, "betas": (0.96, 0.999)}
        adam = optim.Adam(adam_config)
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

        n_steps = 12001
        for _ in range(n_steps):
            svi.step()

        mu_error = param_abs_error("mu_q_log", self.log_mu_n)
        tau_error = param_abs_error("tau_q_log", self.log_tau_n)
        # prec is presumably an absolute tolerance of the project's custom
        # assertEqual -- verify against the TestCase base class.
        self.assertEqual(0.0, mu_error, prec=0.05)
        self.assertEqual(0.0, tau_error, prec=0.05)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号