test_scan.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_alloc_inputs1(self):
        W1 = tensor.matrix('W1')
        W2 = tensor.matrix('W2')
        h0 = tensor.vector('h0')

        def lambda_fn(h, W1, W2):
            return tensor.dot(h, W1 * W2)
        o, _ = theano.scan(lambda_fn,
                           outputs_info=h0,
                           non_sequences=[W1, tensor.zeros_like(W2)],
                           n_steps=5)

        f = theano.function([h0, W1, W2], o, mode=mode_with_opt)
        scan_node = [x for x in f.maker.fgraph.toposort()
                     if isinstance(x.op,
                                   theano.scan_module.scan_op.Scan)][0]
        assert len([x for x in scan_node.op.fn.maker.fgraph.toposort()
                    if isinstance(x.op, theano.tensor.Elemwise)]) == 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号