def test_local_useless_split():
    """Exercise the ``local_useless_split`` rewrite on 1-way and 3-way splits.

    Two functions are compiled with the same mode (the default mode with
    ``local_useless_split`` explicitly included):

    * ``f_opt`` splits ``x`` with ``n_splits=1``; its compiled graph must
      end in a ``DeepCopyOp`` node.
    * ``f_nonopt`` splits ``x`` with ``n_splits=3``; its compiled graph must
      consist of exactly one node, a ``tensor.Split``.

    Both functions are also called once on a random 4x4 input, and
    ``check_stack_trace`` is asserted on each of them.
    """
    x = tensor.matrix('x')
    splits = tensor.ivector('splits')
    # Single-piece split: the case the rewrite under test is aimed at.
    opt = tensor.split(x, splits, n_splits=1)
    # Three-piece split: control case whose Split node must be kept.
    nonopt = tensor.split(x, splits, n_splits=3)

    mode = compile.get_default_mode().including("local_useless_split")
    f_opt = theano.function([x, splits], opt, mode=mode)
    f_nonopt = theano.function([x, splits], nonopt, mode=mode)

    # Run both compiled functions once; the split sizes sum to the 4 rows
    # of the input in each call.
    f_opt(numpy.random.rand(4, 4).astype(config.floatX), [4])
    f_nonopt(numpy.random.rand(4, 4).astype(config.floatX), [1, 2, 1])

    graph_opt = f_opt.maker.fgraph.toposort()
    graph_nonopt = f_nonopt.maker.fgraph.toposort()

    # With n_splits=1 the last node of the graph is a plain copy.
    assert isinstance(graph_opt[-1].op, DeepCopyOp)
    # With n_splits=3 the graph is still just the single Split node.
    assert len(graph_nonopt) == 1
    assert isinstance(graph_nonopt[0].op, tensor.Split)

    # NOTE(review): presumably the rewrite inserts an Assert node into the
    # optimized graph, which is why only Assert is checked here - confirm
    # against the optimization's implementation.
    assert check_stack_trace(f_opt, ops_to_check=[Assert])
    assert check_stack_trace(f_nonopt, ops_to_check='all')
# 评论列表 (web page residue: "comment list")
# 文章目录 (web page residue: "article table of contents")