def test_recursive_lift(self):
    """Check that ``local_dimshuffle_lift`` pushes a transpose down through nested Elemwise nodes.

    The graph built here is ``((v + 42) * (m + 84)).T``.  Its printed form
    (``init_str_g``) has the transpose, ``InplaceDimShuffle{1,0}``, as the
    root, above an ``Elemwise{mul}`` whose inputs are two ``Elemwise{add}``
    nodes.  A single call of the rewrite on the output node must produce
    the graph printed as ``opt_str_g``.  In that graph no DimShuffle
    remains above the ``mul``/``add`` nodes.  The transpose has been
    composed with the DimShuffles already present on the inputs, so it
    ends up applied directly to ``v``, ``m`` and the constants.  The test
    finally checks that stack traces were copied onto the rewritten graph.
    """
    v = T.vector(dtype="float64")
    m = T.matrix(dtype="float64")
    out = ((v + 42) * (m + 84)).T
    g = FunctionGraph([v, m], [out])

    # Before the rewrite: the transpose is the root of the output graph.
    init_str_g = ("[InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
                  "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
                  "(<TensorType(float64, vector)>, "
                  "InplaceDimShuffle{x}(TensorConstant{42}))), "
                  "Elemwise{add,no_inplace}"
                  "(<TensorType(float64, matrix)>, "
                  "InplaceDimShuffle{x,x}(TensorConstant{84}))))]")
    # assertEqual (not assertTrue(a == b)) shows both graph strings when
    # they differ, which is what is needed to debug a mismatch.
    self.assertEqual(str(g), init_str_g)

    # Apply the local rewrite once, to the output node.  Keep the first
    # replacement variable it returns.
    new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
    new_g = FunctionGraph(g.inputs, [new_out])

    # After the rewrite: the DimShuffles sit directly on the inputs and
    # constants.  The vector gets {0,x}, the matrix gets {1,0}, and both
    # constants get {x,x}.
    opt_str_g = ("[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
                 "(InplaceDimShuffle{0,x}(<TensorType(float64, vector)>), "
                 "InplaceDimShuffle{x,x}(TensorConstant{42})), "
                 "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
                 "(<TensorType(float64, matrix)>), "
                 "InplaceDimShuffle{x,x}(TensorConstant{84})))]")
    self.assertEqual(str(new_g), opt_str_g)

    # Check stacktrace was copied over correctly after opt was applied
    self.assertTrue(check_stack_trace(new_g, ops_to_check='all'))
评论列表
文章目录