test_opt.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_local_merge_alloc():
    # Checks that an Alloc applied on top of another Alloc is collapsed into
    # a single Alloc node when "local_merge_alloc" is enabled.
    #
    # The optimization is added to the default mode explicitly;
    # otherwise the test fails under FAST_COMPILE.
    base_mode = theano.compile.mode.get_default_mode()
    mode_with_opt = base_mode.including("local_merge_alloc")

    x = T.iscalar('x')
    y = T.iscalar('y')
    y2 = T.iscalar('y2')
    z = T.iscalar('z')
    w = T.iscalar('w')
    m = T.fscalar('m')

    # Every successful call below is expected to produce this shape.
    expected_shape = (1, 2, 3, 4)

    def build(inputs, graph):
        # Compile `graph` with the optimization enabled and hand back the
        # compiled function plus its nodes in topological order.
        compiled = theano.function(inputs, graph, mode=mode_with_opt)
        return compiled, compiled.maker.fgraph.toposort()

    # Case 1: the inner Alloc has the same number of dimensions as the
    # outer one.
    # Alloc(Alloc(m, 1, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    nested = T.alloc(T.alloc(m, 1, y, 1, 1), x, y, z, w)
    fn, nodes = build([m, x, y, z, w], nested)
    assert len(nodes) == 1
    assert isinstance(nodes[0].op, T.Alloc)
    result = fn(0., 1, 2, 3, 4)
    assert result.shape == expected_shape

    # Case 2: the inner Alloc has fewer dimensions than the outer one.
    # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    nested = T.alloc(T.alloc(m, y, 1, 1), x, y, z, w)
    fn, nodes = build([m, x, y, z, w], nested)
    assert len(nodes) == 1
    assert isinstance(nodes[0].op, T.Alloc)
    result = fn(0., 1, 2, 3, 4)
    assert result.shape == expected_shape

    # Case 3: inner and outer Alloc use different symbolic sizes (y vs. y2)
    # for the same dimension, so an Assert node is expected in the graph.
    # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) ->
    #   Alloc(m, x, assert(y1, y1==y2), z, w)
    nested = T.alloc(T.alloc(m, y, 1, 1), x, y2, z, w)
    fn, nodes = build([m, x, y, y2, z, w], nested)
    assert len(nodes) == 3
    assert isinstance(nodes[-2].op, T.opt.Assert)
    assert isinstance(nodes[-1].op, T.Alloc)
    result = fn(0., 1, 2, 2, 3, 4)
    assert result.shape == expected_shape
    # Mismatching sizes (y=2, y2=5) must be rejected at run time.
    assert_raises((AssertionError, ValueError), fn, 0., 1, 2, 5, 3, 4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号