def test_local_useless_alloc():
    """Check that nested Alloc ops are collapsed by the alloc rewrites.

    Each case builds ``alloc(alloc(m, *inner_shape), *outer_shape)`` over
    integer scalar shape variables, runs ``local_useless_alloc`` and
    ``local_merge_alloc`` (each wrapped with ``out2in``) on a
    ``FunctionGraph``, and inspects the toposorted nodes that remain.
    """
    useless_alloc = out2in(local_useless_alloc)
    merge_alloc = out2in(local_merge_alloc)

    x = T.iscalar('x')
    y = T.iscalar('y')
    y2 = T.iscalar('y2')
    z = T.iscalar('z')
    w = T.iscalar('w')
    m = T.fscalar('m')

    def optimized_topo(inputs, output):
        """Build a graph for ``output``, apply the alloc rewrites, return its nodes.

        The passes run useless -> merge -> useless. The second
        ``useless_alloc`` pass presumably removes any Alloc that only became
        useless after merging (TODO confirm against the optimizer).
        """
        g = FunctionGraph(inputs, [output])
        useless_alloc.optimize(g)
        merge_alloc.optimize(g)
        useless_alloc.optimize(g)
        return g.toposort()

    # case 1
    # Alloc(Alloc(m, 1, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    topo = optimized_topo([m, x, y, z, w],
                          T.alloc(T.alloc(m, 1, y, 1, 1), x, y, z, w))
    assert len(topo) == 1, topo
    assert isinstance(topo[0].op, T.Alloc), topo

    # case 2
    # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
    # The inner (y, 1, 1) shape lines up with the outer trailing dims
    # (y, z, w), and the shared dim is the same variable ``y``.
    topo = optimized_topo([m, x, y, z, w],
                          T.alloc(T.alloc(m, y, 1, 1), x, y, z, w))
    assert len(topo) == 1, topo
    assert isinstance(topo[0].op, T.Alloc), topo

    # case 3
    # Alloc(Alloc(m, y, 1, 1), x, y2, z, w) ->
    #   Alloc(m, x, assert(y, y == y2), z, w)
    # Here the shared dim uses two different variables (y vs y2), so the
    # merged Alloc is expected to be guarded by an Assert node.
    topo = optimized_topo([m, x, y, y2, z, w],
                          T.alloc(T.alloc(m, y, 1, 1), x, y2, z, w))
    # Only the last two nodes are checked; the first is presumably the
    # y == y2 comparison feeding the Assert.
    assert len(topo) == 3, topo
    assert isinstance(topo[-2].op, T.opt.Assert), topo
    assert isinstance(topo[-1].op, T.Alloc), topo
评论列表
文章目录