test_inference.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def setUp(self):
        """Build the model/guide callables used by the tests and store them on ``self``.

        Four closures are created and exposed as ``self.guide``, ``self.model``,
        ``self.duplicate_model`` and ``self.duplicate_obs``.  Every site uses
        ``dist.normal`` with ``ng_zeros(1)`` as its first distribution argument.
        NOTE(review): the original comment described this fixture as
        "normal-normal; known covariance" -- confirm against the test cases.
        """

        def guide():
            # One learnable parameter "p" is shared by both sample sites.
            # Besides "mu_q" the guide also samples "mu_q_2", a site name that
            # none of the models defined in this fixture uses.
            scale = pyro.param("p", Variable(torch.ones(1), requires_grad=True))
            for site_name in ("mu_q", "mu_q_2"):
                pyro.sample(site_name, dist.normal, ng_zeros(1), scale)

        self.guide = guide

        def model():
            # Baseline model: a single sample site named "mu_q".
            pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))

        self.model = model

        def model_dup():
            # Registers a param and a sample site under the same name "mu_q".
            initial_value = Variable(torch.ones(1), requires_grad=True)
            pyro.param("mu_q", initial_value)
            pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))

        self.duplicate_model = model_dup

        def model_obs_dup():
            # Registers a sample site and an observe site under the same name "mu_q".
            pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
            pyro.observe(
                "mu_q", dist.normal, ng_zeros(1), ng_ones(1), ng_zeros(1)
            )

        self.duplicate_obs = model_obs_dup
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号