test_sampling.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def setUp(self):

        pyro.clear_param_store()

        def model():
            mu = pyro.sample("mu", Normal(Variable(torch.zeros(1)),
                                          Variable(torch.ones(1))))
            xd = Normal(mu, Variable(torch.ones(1)), batch_size=50)
            pyro.observe("xs", xd, self.data)
            return mu

        def guide():
            return pyro.sample("mu", Normal(Variable(torch.zeros(1)),
                                            Variable(torch.ones(1))))

        # data
        self.data = Variable(torch.zeros(50, 1))
        self.mu_mean = Variable(torch.zeros(1))
        self.mu_stddev = torch.sqrt(Variable(torch.ones(1)) / 51.0)

        # model and guide
        self.model = model
        self.guide = guide
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号