test_enum.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def gmm_batch_guide(data):
    with pyro.iarange("data", len(data)) as batch:
        n = len(batch)
        ps = pyro.param("ps", Variable(torch.ones(n, 1) * 0.6, requires_grad=True))
        ps = torch.cat([ps, 1 - ps], dim=1)
        z = pyro.sample("z", dist.Categorical(ps))
        assert z.size() == (n, 2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号