test_mapdata.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def test_batch_dim(batch_dim):

    data = Variable(torch.randn(4, 5, 7))

    def local_model(ixs, _xs):
        xs = _xs.view(-1, _xs.size(2))
        return pyro.sample("xs", dist.normal,
                           xs, Variable(torch.ones(xs.size())))

    def model():
        return pyro.map_data("md", data, local_model,
                             batch_size=1, batch_dim=batch_dim)

    tr = poutine.trace(model).get_trace()
    assert tr.nodes["xs"]["value"].size(0) == data.size(1 - batch_dim)
    assert tr.nodes["xs"]["value"].size(1) == data.size(2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号