test_sampling.py source code

python

Project: pyro  Author: uber
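# Imports this excerpt relies on. Written against an early Pyro release:
# pyro.observe and pyro.ones are not part of the current Pyro API.
import torch
import pyro
from torch.autograd import Variable
from pyro.distributions import Bernoulli


def setUp(self):

    # simple HMM with Bernoulli transitions and Bernoulli emissions
    def model():
        # p_latent[i] = P(next latent is 1 | previous latent is i)
        p_latent = pyro.param("p1", Variable(torch.Tensor([[0.7], [0.3]])))
        # p_obs[i] = P(observation is 1 | current latent is i)
        p_obs = pyro.param("p2", Variable(torch.Tensor([[0.9], [0.1]])))

        # the chain starts from a fixed latent state of 1
        latents = [Variable(torch.ones(1, 1))]
        observes = []
        for t in range(self.model_steps):

            # transition: Bernoulli probability indexed by the previous latent
            latents.append(
                pyro.sample("latent_{}".format(t),
                            Bernoulli(torch.index_select(p_latent, 0, latents[-1].view(-1).long()))))

            # emission: score self.data[t] given the latent just sampled
            observes.append(
                pyro.observe("observe_{}".format(t),
                             Bernoulli(torch.index_select(p_obs, 0, latents[-1].view(-1).long())),
                             self.data[t]))
        # return value: sum of all latents, including the initial one
        return torch.sum(torch.cat(latents))

    self.model_steps = 3
    self.data = [pyro.ones(1, 1) for _ in range(self.model_steps)]
    self.model = model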
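The model in `setUp` describes a two-state HMM: each latent is drawn from a Bernoulli whose probability is looked up by the previous latent, and each observation from a Bernoulli looked up by the current latent. Below is a minimal plain-Python sketch of that generative process, assuming the parameters stay at their initial `pyro.param` values. `sample_hmm`, `P_LATENT` and `P_OBS` are hypothetical names, not part of Pyro or the test file, and unlike `pyro.observe` the sketch draws observations forward instead of scoring `self.data`.

```python
import random

# Hypothetical constants copied from the initial pyro.param values in setUp.
# List index = conditioning latent state (0 or 1); value = P(next value is 1).
P_LATENT = [0.7, 0.3]  # P(z_t = 1 | z_{t-1})
P_OBS = [0.9, 0.1]     # P(x_t = 1 | z_t)


def sample_hmm(steps, rng=random):
    """Draw one trace from the generative process described by model().

    Returns (latents, observations, total). `latents` includes the fixed
    initial state z_0 = 1, and `total` mirrors torch.sum(torch.cat(latents)).
    """
    latents = [1]  # model() starts from torch.ones(1, 1)
    observations = []
    for _ in range(steps):
        # Transition: probability looked up by the previous latent state.
        z = 1 if rng.random() < P_LATENT[latents[-1]] else 0
        latents.append(z)
        # Emission: probability looked up by the latent state just sampled.
        x = 1 if rng.random() < P_OBS[z] else 0
        observations.append(x)
    return latents, observations, sum(latents)
```

For example, `sample_hmm(3, random.Random(0))` gives one reproducible three-step trace; its third element plays the role of the value `model()` returns.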
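Because `self.data` fixes every observation to 1 and there are only three latent steps, the exact posterior over the model's return value can be found by enumerating all 2^3 latent paths. The test methods that use this fixture are not shown in this excerpt, so this is an illustration under stated assumptions, not the file's own check: it assumes the parameters keep their initial values, and `exact_posterior`, `bernoulli_pmf`, `P_LATENT` and `P_OBS` are hypothetical plain-Python helpers. A sampling-based estimate from `self.model` could be compared against its output.

```python
from collections import defaultdict
from itertools import product

# Hypothetical constants copied from the initial pyro.param values in setUp.
P_LATENT = [0.7, 0.3]  # P(z_t = 1 | z_{t-1}), from pyro.param("p1", ...)
P_OBS = [0.9, 0.1]     # P(x_t = 1 | z_t), from pyro.param("p2", ...)


def bernoulli_pmf(value, p):
    """Probability that a Bernoulli(p) variable takes `value` (0 or 1)."""
    return p if value == 1 else 1.0 - p


def exact_posterior(data, z0=1):
    """Enumerate every latent path to get P(sum of latents | data).

    Returns (posterior, evidence): `posterior` maps each possible return
    value of model() to its posterior probability, and `evidence` is p(data).
    """
    weights = defaultdict(float)
    for path in product((0, 1), repeat=len(data)):
        weight = 1.0
        prev = z0
        for z, x in zip(path, data):
            weight *= bernoulli_pmf(z, P_LATENT[prev])  # transition term
            weight *= bernoulli_pmf(x, P_OBS[z])        # emission term
            prev = z
        # The return value sums the initial latent and every sampled latent.
        weights[z0 + sum(path)] += weight
    evidence = sum(weights.values())
    posterior = {k: w / evidence for k, w in sorted(weights.items())}
    return posterior, evidence


# self.data is three observations, all equal to 1.
posterior, evidence = exact_posterior([1, 1, 1])
print(evidence)
print(posterior)
```

Enumeration costs O(2^T) and is only practical for short chains like this one; for longer chains the forward algorithm computes the same evidence in time linear in the number of steps.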