def gmm_batch_model(data):
    """Two-component Gaussian mixture model over ``data``, batched with ``pyro.iarange``.

    Written against the legacy Pyro API (``Variable``, ``pyro.iarange``,
    ``pyro.observe``).

    Learnable parameters registered with ``pyro.param``:
      * ``"p"``     -- weight of the first component, initialised to 0.3;
                       the second component receives ``1 - p``.
      * ``"sigma"`` -- standard deviation shared by both components,
                       initialised to 1.0.

    The component means are fixed (not learned) at -1 and +1.

    Sample sites:
      * ``"z"`` -- per-point component assignment.
      * ``"x"`` -- observed ``data[batch]`` scored under a Normal whose mean
                   is the assigned component's location.

    Args:
        data: observations indexed by the batch yielded from
            ``pyro.iarange`` (``data[batch]``). Presumably a 1-D tensor or
            Variable, given the length-``n`` Normal it is scored under;
            TODO confirm.
    """
    # Learnable mixing weight, expanded into the full 2-vector [p, 1 - p].
    first_weight = pyro.param("p", Variable(torch.Tensor([0.3]), requires_grad=True))
    mixture_probs = torch.cat([first_weight, 1 - first_weight])

    # Learnable shared scale, plus fixed component locations.
    scale = pyro.param("sigma", Variable(torch.Tensor([1.0]), requires_grad=True))
    component_locs = Variable(torch.Tensor([-1, 1]))

    with pyro.iarange("data", len(data)) as batch_idx:
        batch_size = len(batch_idx)

        # Repeat the mixture probabilities once per data point in the batch.
        per_point_probs = mixture_probs.unsqueeze(0).expand(batch_size, 2)
        assignment = pyro.sample("z", dist.Categorical(per_point_probs))
        # A (batch_size, 2) sample implies Categorical returns a one-hot
        # encoding here (legacy Pyro behaviour). Confirm against the pinned
        # Pyro version.
        assert assignment.size() == (batch_size, 2)

        # With one-hot rows, the matrix-vector product selects each point's
        # component location.
        point_locs = torch.mv(assignment, component_locs)
        pyro.observe("x", dist.Normal(point_locs, scale.expand(batch_size)), data[batch_idx])
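# 评论列表 ("Comment list") -- residue from the web page this code was copied from; not code.
# 文章目录 ("Table of contents") -- residue from the web page this code was copied from; not code.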