def _check_shared_input_output(state, inc, mode):
    """Check a shared variable used as both input and output of two functions.

    Compiles two functions that return ``state`` and update it to
    ``state + inc``. ``f`` is compiled with ``mode`` and ``g`` without an
    explicit mode. Each call returns the value of ``state`` from before that
    call's update. The two functions are called interleaved, so an update
    made through one must be visible to the other.

    ``state`` must currently hold zero (scalar or every array element).
    Comparisons go through ``numpy.all`` so scalar and array shared
    variables are handled the same way.
    """
    # Build a separate update expression for each function, as the original
    # test did, instead of sharing one graph between the two compilations.
    f = theano.function([inc], state, updates=[(state, state + inc)],
                        mode=mode)
    g = theano.function([inc], state, updates=[(state, state + inc)])

    # Initial value.
    f0 = f(0)
    g0 = g(0)
    assert numpy.all(f0 == 0), f0
    assert numpy.all(g0 == 0), g0

    # Increment state via f; the call returns the previous value.
    f2 = f(2)
    assert numpy.all(f2 == f0), (f2, f0)
    # Both functions must observe the increment made through f.
    f0 = f(0)
    g0 = g(0)
    assert numpy.all(f0 == 2), f0
    assert numpy.all(g0 == 2), g0

    # Increment state via g; the call returns the previous value.
    g3 = g(3)
    assert numpy.all(g3 == g0), (g3, g0)
    # Both functions must observe the increment made through g.
    f0 = f(0)
    g0 = g(0)
    assert numpy.all(f0 == 5), f0
    assert numpy.all(g0 == 5), g0


def test_shared_input_output():
    """Regression test: a shared variable that is both input and output.

    Test bug reported on the mailing list by Alberto Orlandi:
    https://groups.google.com/d/topic/theano-users/6dLaEqc2R6g/discussion

    The same scenario runs once with a scalar shared variable and once with
    an int32 vector. Each time, a function compiled with a ``CLinker`` mode
    is compared against one compiled without an explicit mode.
    """
    inc = theano.tensor.iscalar('inc')
    linker = theano.gof.CLinker()
    mode = theano.Mode(linker=linker)

    # Scalar shared variable.
    state = theano.shared(0)
    state.name = 'state'
    _check_shared_input_output(state, inc, mode)

    # Vector shared variable.
    vstate = theano.shared(numpy.zeros(3, dtype='int32'))
    vstate.name = 'vstate'
    _check_shared_input_output(vstate, inc, mode)
# (web-page residue) Comment list
# (web-page residue) Table of contents