test_scan.py source code

python

Project: Theano-Deep-learning  Author: GeekLiB
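import theano
from theano import tensor
from nose.plugins.skip import SkipTest
# mode_with_opt is defined elsewhere in the test module (not shown in this excerpt).


def test_alloc_inputs2(self):
    # Skipped unconditionally: the statements below are never executed.
    raise SkipTest("This test depends on an optimization for "
                   "scan that has not been implemented yet.")
    W1 = tensor.matrix()
    W2 = tensor.matrix()
    h0 = tensor.vector()

    # Step function: sequence slice (one row), recurrent state h,
    # and non-sequence W2, combined with an elementwise product.
    def lambda_fn(W1, h, W2):
        return W1 * tensor.dot(h, W2)

    # Both the sequence and the non-sequence are all-zero tensors built
    # with zeros_like, so every step's output is zero (for finite h0).
    o, _ = theano.scan(lambda_fn,
                       sequences=tensor.zeros_like(W1),
                       outputs_info=h0,
                       non_sequences=[tensor.zeros_like(W2)],
                       n_steps=5)

    f = theano.function([h0, W1, W2], o, mode=mode_with_opt)
    # Locate the Scan node in the optimized outer graph.
    scan_node = [x for x in f.maker.fgraph.toposort()
                 if isinstance(x.op,
                               theano.scan_module.scan_op.Scan)][0]

    # The intended optimization would simplify the inner graph so that
    # no Elemwise op remains in it.
    assert len([x for x in scan_node.op.fn.maker.fgraph.toposort()
                if isinstance(x.op, theano.tensor.Elemwise)]) == 0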
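The test above only makes sense if every step of this scan is zero. A minimal plain-Python sketch of the recurrence it builds, under stated assumptions: `scan_reference` and `zeros_like` below are hypothetical helpers (not Theano APIs), W1 is assumed to have at least `n_steps` rows, and W2 is assumed square so that h·W2 has the same length as h. It models h_t = W1[t] * (h_{t-1} · W2) and shows that zero inputs yield zero outputs for a finite h0, which is our reading of why the test expects no Elemwise in the optimized inner graph.

```python
def scan_reference(W1, h0, W2, n_steps=5):
    """Hypothetical plain-Python model of the scan in test_alloc_inputs2.

    Computes h_t = W1[t] * (h_{t-1} . W2) elementwise for n_steps steps and
    returns the per-step states, like the scan output `o`.
    Assumes W1 has at least n_steps rows and W2 is square.
    """
    h = list(h0)
    outputs = []
    for t in range(n_steps):
        row = W1[t]  # sequence slice for step t
        # Vector-matrix product h . W2: entry j is sum_i h[i] * W2[i][j].
        hw = [sum(h[i] * W2[i][j] for i in range(len(h)))
              for j in range(len(W2[0]))]
        # Elementwise product with the sequence slice gives the new state.
        h = [row[j] * hw[j] for j in range(len(hw))]
        outputs.append(h)
    return outputs


def zeros_like(matrix):
    """Hypothetical stand-in for tensor.zeros_like on a list of lists."""
    return [[0.0] * len(r) for r in matrix]


W1 = [[1.0, 1.0] for _ in range(5)]
W2 = [[1.0, 0.0], [0.0, 1.0]]  # identity, so h is carried through unchanged
h0 = [3.0, -4.0]

print(scan_reference(W1, h0, W2))  # every row is [3.0, -4.0]
print(scan_reference(zeros_like(W1), h0, zeros_like(W2)))  # every row is [0.0, 0.0]
```

The "finite h0" caveat matters: with an infinite entry in h0 the product h·W2 becomes NaN, so an optimizer cannot replace the step by a constant zero unconditionally.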