def test_printing_scan():
    """Smoke-test ``pydotprint`` with ``scan_graphs=True`` on a scan graph.

    Builds a scan whose step function doubles the previous state, then
    calls ``theano.printing.pydotprint`` once on the symbolic scan output
    and once on the compiled function.  There are no explicit assertions:
    the test passes if neither call raises.

    Raises:
        SkipTest: if ``theano.printing.pydot_imported`` is false.
    """
    # Skip test if pydot is not available.
    if not theano.printing.pydot_imported:
        raise SkipTest('pydot not available')

    def f_pow2(x_tm1):
        # Step function: each iteration doubles the previous state.
        return 2 * x_tm1

    state = theano.tensor.scalar('state')
    n_steps = theano.tensor.iscalar('nsteps')

    # No sequences and no non-sequences: the recurrence is driven only by
    # the initial state and the number of steps.
    output, updates = theano.scan(
        f_pow2,
        sequences=[],
        outputs_info=state,
        non_sequences=[],
        n_steps=n_steps,
        truncate_gradient=-1,
        go_backwards=False,
    )
    f = theano.function(
        [state, n_steps],
        output,
        updates=updates,
        allow_input_downcast=True,
    )

    # Exercise both supported inputs: a symbolic variable and a compiled
    # function.
    theano.printing.pydotprint(output, scan_graphs=True)
    theano.printing.pydotprint(f, scan_graphs=True)
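# Comment list (web-page navigation residue; not part of the test code)
# Article table of contents (web-page navigation residue; not part of the test code)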