test_scan.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_inplace3(self):
        """Check which scan outputs run in place after one input is swapped.

        A three-step scan is built over two scalar shared variables, each
        incremented by one per step.  One input feeding the first output
        is then replaced by a constant through ``theano.clone``, and the
        graph is compiled with the 'inplace' optimizations included.  The
        compiled Scan op must list output 1 in its ``destroy_map`` and must
        not list output 0.
        """
        rng = numpy.random.RandomState(utt.fetch_seed())

        # Draw order matters for reproducibility with the fixed seed.
        init0 = asarrayX(rng.uniform())
        init1 = asarrayX(rng.uniform())
        state0 = theano.shared(init0)
        state1 = theano.shared(init1)

        def step(prev0, prev1):
            return (prev0 + asarrayX(1),
                    prev1 + asarrayX(1))

        outputs, updates = theano.scan(step,
                                       sequences=[],
                                       outputs_info=[state0, state1],
                                       n_steps=3)

        # Length-3 vector whose first entry is the first initial value; it
        # is wrapped as a constant and substituted into the graph below.
        buf = asarrayX(numpy.zeros((3,)))
        buf[0] = init0
        const_input = theano.tensor.constant(buf)

        # NOTE(review): presumably this is the variable carrying the first
        # state's initial value into the scan node -- confirm against the
        # graph that theano.scan builds.
        to_replace = outputs[0].owner.inputs[0].owner.inputs[1]
        outputs = theano.clone(outputs,
                               replace=[(to_replace, const_input)])

        inplace_mode = theano.compile.mode.get_mode(None).including('inplace')
        f9 = theano.function([],
                             outputs,
                             updates=updates,
                             mode=inplace_mode)

        scan_nodes = [node for node in f9.maker.fgraph.toposort()
                      if isinstance(node.op, theano.scan_module.scan_op.Scan)]
        destroy_map = scan_nodes[0].op.destroy_map
        assert 0 not in destroy_map
        assert 1 in destroy_map

    # Shared variable with updates
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号