def test_do_propagation(self):
    """Check that a ``poutine.do`` intervention on ``z`` propagates to ``flip``.

    In ``model``, ``flip`` depends on the latent ``z`` through ``sigmoid(z)``.
    Under the unmodified model ``z`` is pinned near 10, so ``flip`` is 1 with
    high probability. Forcing ``z = -10`` with ``poutine.do`` should instead
    make ``flip`` 0 with high probability.

    NOTE(review): both assertions are stochastic. Each one fails with
    probability of roughly sigmoid(-10) ~= 4.5e-5 per run.
    """
    pyro.clear_param_store()

    def model():
        # A tiny scale (0.0001) pins z at essentially 10.0.
        z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
        # Mathematically this equals exp(z) / (exp(z) + 1). torch.sigmoid
        # avoids overflowing exp() (inf / inf -> NaN) when z is large.
        latent_prob = torch.sigmoid(z)
        flip = pyro.sample("flip", Bernoulli(latent_prob))
        return flip

    sample_from_model = model()
    z_data = {"z": -10.0 * ng_ones(1)}
    # Under the model flip = 1 with high probability, so perform indirect
    # do-surgery on the upstream latent z to make flip = 0.
    sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()
    assert eq(sample_from_model, ng_ones(1))
    assert eq(sample_from_do_model, ng_zeros(1))
评论列表
文章目录