test_opt.py 文件源码

python
阅读 83 收藏 0 点赞 0 评论 0

项目：Theano-Deep-learning　作者：GeekLiB　项目源码　文件源码
def test_local_useless_split():
    """Verify the ``local_useless_split`` graph rewrite.

    Two symbolic splits of the same matrix are compiled with the rewrite
    enabled:

    * a split with ``n_splits=1``, whose compiled graph is expected to end
      in a ``DeepCopyOp`` (i.e. the ``Split`` itself has been rewritten);
    * a split with ``n_splits=3``, which must survive as the single
      ``Split`` node of its graph.

    Stack-trace propagation is checked on both compiled functions.
    """
    inp = tensor.matrix('x')
    sizes = tensor.ivector('splits')
    single_split = tensor.split(inp, sizes, n_splits=1)
    triple_split = tensor.split(inp, sizes, n_splits=3)

    mode = compile.get_default_mode().including("local_useless_split")
    fn_single = theano.function([inp, sizes], single_split, mode=mode)
    fn_triple = theano.function([inp, sizes], triple_split, mode=mode)

    # Run each compiled function once on concrete data; the split sizes
    # sum to the 4 rows of the input in both cases.
    fn_single(numpy.random.rand(4, 4).astype(config.floatX), [4])
    fn_triple(numpy.random.rand(4, 4).astype(config.floatX), [1, 2, 1])

    nodes_single = fn_single.maker.fgraph.toposort()
    nodes_triple = fn_triple.maker.fgraph.toposort()

    # Single-way split: the last scheduled node should be the output copy.
    last_single_op = nodes_single[-1].op
    assert isinstance(last_single_op, DeepCopyOp)

    # Three-way split: nothing to simplify, exactly one Split node remains.
    assert len(nodes_triple) == 1
    only_triple_op = nodes_triple[0].op
    assert isinstance(only_triple_op, tensor.Split)

    # NOTE(review): the Assert ops checked here are presumably inserted by
    # the rewrite to guard the split count; confirm against the optimizer.
    assert check_stack_trace(fn_single, ops_to_check=[Assert])
    assert check_stack_trace(fn_triple, ops_to_check='all')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号