test_mapdata.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_nested_map_data():
    """Check nested ``pyro.map_data`` batch sizes and sample-site scales.

    An outer ``map_data`` over 8 means (batch size 2) wraps an inner
    ``map_data`` over 6 stds (batch size 3). The test asserts that:

    * the outer call returns ``mean_batch_size`` results, each holding
      ``std_batch_size`` inner results;
    * every traced sample site named ``x_*`` carries a ``scale`` equal to
      ``(len(means) / mean_batch_size) * (len(stds) / std_batch_size)``,
      i.e. the ``4.0 * 2.0`` the test previously hard-coded.
    """
    means = [Variable(torch.randn(2)) for _ in range(8)]
    mean_batch_size = 2
    # torch.abs keeps the randomly drawn standard deviations non-negative.
    stds = [Variable(torch.abs(torch.randn(2))) for _ in range(6)]
    std_batch_size = 3

    def model(means, stds):
        def outer(i, x):
            def inner(j, y):
                # Site names encode both the outer and the inner index.
                return pyro.sample("x_{}{}".format(i, j), dist.normal, x, y)

            return pyro.map_data("a_{}".format(i), stds, inner,
                                 batch_size=std_batch_size)

        return pyro.map_data("a", means, outer, batch_size=mean_batch_size)

    xs = model(means, stds)
    assert len(xs) == mean_batch_size
    assert len(xs[0]) == std_batch_size

    # Derive the expected scale from the sizes above instead of a magic
    # constant, so the assertion stays valid if those sizes are edited.
    # float() keeps the division exact under Python 2 semantics as well.
    expected_scale = (
        (float(len(means)) / mean_batch_size)
        * (float(len(stds)) / std_batch_size)
    )

    tr = poutine.trace(model).get_trace(means, stds)
    for name in tr.nodes.keys():
        node = tr.nodes[name]
        if node["type"] == "sample" and name.startswith("x_"):
            assert node["scale"] == expected_scale
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号