def gmm_batch_guide(data):
    """Guide for a batched two-component mixture over ``data``.

    Inside a ``pyro.iarange`` named "data" that spans ``len(data)`` indices,
    it registers a learnable per-datum probability param "ps" and builds a
    two-column probability table ``[ps, 1 - ps]``. It then samples the
    assignment site "z" from a ``Categorical`` over that table.

    The final assertion requires one sample row of width 2 per datum. That
    is consistent with one-hot encoded ``Categorical`` samples in the Pyro
    version this targets (TODO confirm against the installed Pyro version).

    NOTE(review): the initial value's shape depends on ``len(data)``. If "ps"
    is already in the param store from a call with a different data length,
    the shapes may disagree. Confirm callers always pass the same length.
    """
    with pyro.iarange("data", len(data)) as indices:
        num_points = len(indices)

        # Every datum starts with probability 0.6 for the first component.
        initial = torch.ones(num_points, 1) * 0.6
        first = pyro.param("ps", Variable(initial, requires_grad=True))

        # Pair each probability with its complement -> (num_points, 2) table.
        probs = torch.cat([first, 1 - first], dim=1)

        assignment = pyro.sample("z", dist.Categorical(probs))
        assert assignment.size() == (num_points, 2)
评论列表
文章目录