def test_constant_folding():
    """Test that constant folding is registered in FAST_COMPILE mode.

    Regression test: a previous error caused this registration to be
    dropped while the optimizations were being registered.
    """
    # Both checks below compile with the same mode. "fusion" is excluded,
    # presumably so elemwise nodes are not fused together and the node
    # counts asserted below stay predictable — TODO confirm.
    mode = theano.compile.get_mode("FAST_COMPILE").excluding("fusion")

    # Symbolic vector input: the compiled graph is expected to contain
    # exactly two apply nodes.
    x = tensor.dvector()
    f = theano.function([x], [x * 2, x + x], mode=mode)
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 2, topo

    # Constant-folding elemwise ops on a scalar constant must not crash,
    # since such ops should not generate C code. The assertions expect the
    # compiled graph to contain only DeepCopyOp nodes, i.e. both outputs
    # were folded into constants.
    c = tensor.constant(3)
    assert c.ndim == 0
    f = theano.function([], [c * 2, c + c], mode=mode)
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 2, topo
    assert all(isinstance(node.op, DeepCopyOp) for node in topo), topo
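# Stray web-page text (not code): "评论列表" (Comment list)
# Stray web-page text (not code): "文章目录" (Table of contents)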