def gmm_guide(data, verbose=False):
    """Sample one Bernoulli site ``z_<i>`` for every index of ``data``.

    For each index ``i`` yielded by ``pyro.irange("data", len(data))`` a
    parameter named ``p_<i>`` is obtained through ``pyro.param`` (given a
    one-element tensor holding 0.6) and passed to ``dist.Bernoulli``, which
    is then sampled at the site named ``z_<i>``.

    Args:
        data: Sized collection; only ``len(data)`` is read here.
        verbose: When true, print every sampled value, prefixed by ``i``
            spaces of padding.

    Returns:
        None. The function only registers params/sample sites and
        optionally prints.
    """
    num_points = len(data)
    for idx in pyro.irange("data", num_points):
        param_name = "p_{}".format(idx)
        initial_value = Variable(torch.Tensor([0.6]), requires_grad=True)
        prob = pyro.param(param_name, initial_value)

        site_name = "z_{}".format(idx)
        draw = pyro.sample(site_name, dist.Bernoulli(prob))
        # Sanity check: the sample is expected to hold exactly one element.
        assert draw.size() == (1,)

        # Cast to long and take the first (only) element for display.
        label = draw.long().data[0]
        if not verbose:
            continue
        padding = " " * idx
        print("G{} z_{} = {}".format(padding, idx, label))
评论列表
文章目录