def gmm_model(data, verbose=False):
    """Two-component Gaussian mixture model using an early Pyro API.

    For every observation ``data[idx]`` a Bernoulli assignment site
    ``z_<idx>`` selects one of two fixed component means (-1 or 1). The
    observation is then scored at site ``x_<idx>`` under a Normal with the
    selected mean and a single learnable scale shared by all data.

    Args:
        data: indexable sequence of observations; ``len(data)`` pairs of
            sites are created. Presumably each element is a tensor/Variable
            accepted by ``dist.Normal`` -- TODO confirm against callers.
        verbose: if True, print each sampled assignment, prefixed by
            ``idx`` spaces so successive lines form a staircase.

    Returns:
        None. The model acts only through the Pyro calls
        ``pyro.param`` / ``pyro.sample`` / ``pyro.observe``.
    """
    # Learnable Bernoulli probability, stored in the param store as "p".
    weight = pyro.param("p", Variable(torch.Tensor([0.3]), requires_grad=True))
    # Learnable scale shared by both components, stored as "sigma".
    scale = pyro.param("sigma", Variable(torch.Tensor([1.0]), requires_grad=True))
    # Fixed (non-learnable) component means: index 0 -> -1, index 1 -> 1.
    component_locs = Variable(torch.Tensor([-1, 1]))

    for idx in pyro.irange("data", len(data)):
        assignment = pyro.sample("z_{}".format(idx), dist.Bernoulli(weight))
        # Each assignment site is expected to draw exactly one value.
        assert assignment.size() == (1,)
        # Turn the 0/1 draw into an integer index into component_locs.
        component = assignment.long().data[0]
        if verbose:
            indent = " " * idx
            print("M{} z_{} = {}".format(indent, idx, component))
        chosen_loc = component_locs[component]
        pyro.observe("x_{}".format(idx), dist.Normal(chosen_loc, scale), data[idx])
评论列表
文章目录