test_pickle_unpickle_theano_fn.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_pickle_unpickle_with_reoptimization():
    """Check that a compiled function survives a pickle round trip.

    A function of two float matrices with updates on two shared variables
    is compiled and pickled. It is then unpickled while
    ``theano.config.reoptimize_unpickled_function`` is forced to ``True``,
    and the original and restored functions must return equal results
    for the same inputs. The config flag is restored afterwards.
    """
    # Compile with FAST_RUN when the configured mode is the debug mode.
    compile_mode = theano.config.mode
    if compile_mode in ("DEBUG_MODE", "DebugMode"):
        compile_mode = "FAST_RUN"

    # Symbolic inputs and shared state.
    lhs = T.fmatrix('x1')
    rhs = T.fmatrix('x2')
    shared_a = theano.shared(numpy.ones((10, 10), dtype=floatX))
    shared_b = theano.shared(numpy.ones((10, 10), dtype=floatX))
    cost = T.sum(T.sum(T.sum(lhs ** 2 + rhs) + shared_a) + shared_b)

    # Each call increments both shared variables by one.
    shared_updates = OrderedDict([
        (shared_a, shared_a + 1),
        (shared_b, shared_b + 1),
    ])
    original_fn = theano.function(
        [lhs, rhs], cost, updates=shared_updates, mode=compile_mode)

    # Concrete inputs used to compare the two functions.
    first_input = numpy.ones((10, 10), dtype=floatX)
    second_input = numpy.ones((10, 10), dtype=floatX)

    # Serialize the compiled function before it has ever been called.
    pickled_fn = pickle.dumps(original_fn, -1)

    # Unpickle with re-optimization switched on, then put the flag back to
    # whatever it was, even if the comparison fails.
    saved_flag = theano.config.reoptimize_unpickled_function
    try:
        theano.config.reoptimize_unpickled_function = True
        restored_fn = pickle.loads(pickled_fn)
        expected = original_fn(first_input, second_input)
        actual = restored_fn(first_input, second_input)
        assert expected == actual
    finally:
        theano.config.reoptimize_unpickled_function = saved_flag
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号