test_cc.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目: Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_shared_input_output():
    """Check a shared variable that is both an input and an output of a function.

    Regression test for a bug reported on the theano-users mailing list by
    Alberto Orlandi:
    https://groups.google.com/d/topic/theano-users/6dLaEqc2R6g/discussion

    For each shared variable, two functions are compiled with the same
    output and update rule (``shared += inc``). One uses a ``CLinker``-based
    mode and the other uses the default mode. The test checks that:

    * each call returns the shared value from *before* its update, and
    * an update made through either function is seen by the other.

    The scenario runs first for a scalar shared variable and then for an
    int32 vector of length 3.
    """
    inc = theano.tensor.iscalar('inc')
    c_linker_mode = theano.Mode(linker=theano.gof.CLinker())

    def compile_pair(shared_var):
        # Build (C-linker function, default-mode function). Each one outputs
        # ``shared_var`` and updates it to ``shared_var + inc``.
        via_c = theano.function([inc], shared_var,
                                updates=[(shared_var, shared_var + inc)],
                                mode=c_linker_mode)
        via_default = theano.function([inc], shared_var,
                                      updates=[(shared_var, shared_var + inc)])
        return via_c, via_default

    def expect_everywhere(expected, *outputs):
        # Element-wise check that works for both scalar and vector outputs.
        for out in outputs:
            assert numpy.all(out == expected), (outputs, expected)

    def run_scenario(via_c, via_default):
        # Both functions first observe the initial value, 0.
        c_out = via_c(0)
        default_out = via_default(0)
        expect_everywhere(0, c_out, default_out)

        # Incrementing through the C-linker function returns the value held
        # before the update...
        previous = via_c(2)
        assert numpy.all(previous == c_out), (previous, c_out)
        # ...and the new value is then visible to both functions.
        c_out = via_c(0)
        default_out = via_default(0)
        expect_everywhere(2, c_out, default_out)

        # Same check, with the default-mode function doing the increment.
        previous = via_default(3)
        assert numpy.all(previous == default_out), (previous, default_out)
        c_out = via_c(0)
        default_out = via_default(0)
        expect_everywhere(5, c_out, default_out)

    # Scalar shared variable.
    state = theano.shared(0)
    state.name = 'state'
    run_scenario(*compile_pair(state))

    # Vector (int32, length 3) shared variable.
    vstate = theano.shared(numpy.zeros(3, dtype='int32'))
    vstate.name = 'vstate'
    run_scenario(*compile_pair(vstate))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号