test_poutines.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_do_propagation(self):
        pyro.clear_param_store()

        def model():
            z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
            latent_prob = torch.exp(z) / (torch.exp(z) + ng_ones(1))
            flip = pyro.sample("flip", Bernoulli(latent_prob))
            return flip

        sample_from_model = model()
        z_data = {"z": -10.0 * ng_ones(1)}
        # under model flip = 1 with high probability; so do indirect DO surgery to make flip = 0
        sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()

        assert eq(sample_from_model, ng_ones(1))
        assert eq(sample_from_do_model, ng_zeros(1))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号