test_printing.py source code

python

Project: Theano-Deep-learning    Author: GeekLiB
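# Imports the excerpt relies on (SkipTest is assumed to come from nose,
# as is usual in Theano's test suite).
from nose.plugins.skip import SkipTest

import theano


def test_printing_scan():
    # Skip the test if pydot is not available.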
    if not theano.printing.pydot_imported:
        raise SkipTest('pydot not available')

    def f_pow2(x_tm1):
        # One scan step: double the previous state.
        return 2 * x_tm1

    state = theano.tensor.scalar('state')
    n_steps = theano.tensor.iscalar('nsteps')
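    # Apply f_pow2 n_steps times starting from `state`; there are no
    # sequences and no non-sequences, only the recurrent output.
    output, updates = theano.scan(f_pow2,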
                                  [],
                                  state,
                                  [],
                                  n_steps=n_steps,
                                  truncate_gradient=-1,
                                  go_backwards=False)
    f = theano.function([state, n_steps],
                        output,
                        updates=updates,
                        allow_input_downcast=True)
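    # Render both the symbolic output and the compiled function,
    # including the inner graphs of the scan op.
    theano.printing.pydotprint(output, scan_graphs=True)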
    theano.printing.pydotprint(f, scan_graphs=True)
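
For readers without Theano installed, the recurrence that the scan in this test builds can be sketched in plain Python. This is only an illustration of the values the loop would produce, not part of the original test; the helper name `scan_pow2_reference` is hypothetical.

```python
def scan_pow2_reference(state, n_steps):
    """Plain-Python stand-in for the scan in the test (hypothetical helper).

    Starting from `state`, doubles the value at every step and collects
    the result of each step, mirroring what f_pow2 does inside the scan.
    """
    outputs = []
    current = state
    for _ in range(n_steps):
        current = 2 * current  # same update as f_pow2
        outputs.append(current)
    return outputs


if __name__ == "__main__":
    print(scan_pow2_reference(1.0, 3))
```

The initial state itself is not included in the returned list; only the value computed at each of the `n_steps` iterations is collected.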