def test_eliminate_nonseqs(self):
    """Run a scan whose step function reads only its last argument.

    The step function ignores every recurrent input and builds all of
    its outputs (and the update of a shared variable) from the final
    positional argument it receives, which is the scalar ``W`` handed
    to ``theano.scan``.  The compiled function is evaluated once with a
    random ``W`` and each returned value, plus the shared variable, is
    compared with ``W`` shifted by a fixed constant.

    NOTE(review): judging by the test name this presumably exercises
    the handling/elimination of non-sequences inside scan -- confirm
    against the optimization it is meant to cover.
    """
    weight = tensor.scalar('W')
    shared_var = theano.shared(asarrayX(2.))
    vec_init = tensor.vector('x1')
    scal_init = tensor.scalar('x2')

    def rec_fn(*args):
        # Only the last argument is used; all others are ignored.
        nonseq = args[-1]
        step_outputs = (nonseq + 1.,  # mitsot
                        nonseq + 2.,  # sitsot
                        nonseq + 3.)  # nitsot
        shared_updates = {shared_var: nonseq + 4.}  # shared
        return [step_outputs, shared_updates]

    # One entry per output of rec_fn: the first is seeded from ``x1``
    # with taps -1 and -3, the second from ``x2``, and the third has no
    # initial state at all.
    initial_states = [
        {'initial': vec_init, 'taps': [-1, -3]},
        scal_init,
        None,
    ]
    (out_first, out_second, out_third), updates = theano.scan(
        rec_fn,
        [],
        initial_states,
        weight,
        n_steps=5,
        truncate_gradient=-1,
        go_backwards=False)

    python_mode = theano.Mode(linker='py')
    f = theano.function(
        [weight, vec_init, scal_init],
        [out_first, out_second, out_third],
        updates=updates,
        mode=python_mode,
        allow_input_downcast=True)

    rng = numpy.random.RandomState(utt.fetch_seed())
    w_value = asarrayX(rng.uniform())
    results = f(w_value, [0, 0, 0], 0)

    # Output k (0-based) is expected to equal W + (k + 1).
    for position, offset in enumerate((1, 2, 3)):
        utt.assert_allclose(results[position], w_value + offset)
    # The shared variable must have received its update as well.
    utt.assert_allclose(shared_var.get_value(), w_value + 4)
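# NOTE: the two lines that followed here were web-page navigation headings
# ("comment list" and "article table of contents") left over from scraping;
# they are not part of the test code.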