def test_local_merge_alloc():
    """Check how nested ``Alloc`` nodes compile with ``local_merge_alloc``.

    Three graphs of the form ``alloc(alloc(m, ...), x, ..., z, w)`` are
    compiled with the ``local_merge_alloc`` optimization added to the
    default mode.  For each graph the test inspects the topologically
    sorted nodes of the compiled function and the shape of the value the
    function returns.
    """
    # Add this opt to the default mode,
    # otherwise, FAST_COMPILE fails.
    default_mode = theano.compile.mode.get_default_mode()
    opt_mode = default_mode.including("local_merge_alloc")

    # Symbolic inputs: integer scalars for the dimensions, a float scalar
    # for the value being broadcast.
    x = T.iscalar('x')
    y = T.iscalar('y')
    y2 = T.iscalar('y2')
    z = T.iscalar('z')
    w = T.iscalar('w')
    m = T.fscalar('m')

    # case 1
    # Alloc(Alloc(m, 1, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    output = T.alloc(T.alloc(m, 1, y, 1, 1), x, y, z, w)
    f = theano.function([m, x, y, z, w], output, mode=opt_mode)
    topo = f.maker.fgraph.toposort()
    # The two allocs are expected to collapse into a single Alloc node.
    assert len(topo) == 1
    assert isinstance(topo[0].op, T.Alloc)
    o = f(0., 1, 2, 3, 4)
    assert o.shape == (1, 2, 3, 4)

    # case 2
    # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    output = T.alloc(T.alloc(m, y, 1, 1), x, y, z, w)
    f = theano.function([m, x, y, z, w], output, mode=opt_mode)
    topo = f.maker.fgraph.toposort()
    # Same expectation as case 1, with an inner alloc of lower rank.
    assert len(topo) == 1
    assert isinstance(topo[0].op, T.Alloc)
    o = f(0., 1, 2, 3, 4)
    assert o.shape == (1, 2, 3, 4)

    # case 3 (y1 in the sketch below is the variable `y`)
    # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) ->
    #   Alloc(m, x, assert(y1, y1==y2), z, w)
    output = T.alloc(T.alloc(m, y, 1, 1), x, y2, z, w)
    f = theano.function([m, x, y, y2, z, w], output, mode=opt_mode)
    topo = f.maker.fgraph.toposort()
    # Three nodes are expected here, ending with an Assert followed by the
    # merged Alloc.
    assert len(topo) == 3
    assert isinstance(topo[-2].op, T.opt.Assert)
    assert isinstance(topo[-1].op, T.Alloc)
    # Matching sizes (y == y2 == 2) run normally.
    o = f(0., 1, 2, 2, 3, 4)
    assert o.shape == (1, 2, 3, 4)
    # Mismatched sizes (y == 2, y2 == 5) must be rejected at call time.
    assert_raises((AssertionError, ValueError), f, 0., 1, 2, 5, 3, 4)
# Page-navigation residue from the source web page (not code):
# "Comment list" / "Table of contents".