test_scan.py 文件源码

python
阅读 47 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_only_nonseq_inputs(self):
        """Exercise ``theano.scan`` when it is given only a non-sequence input.

        The scan step is the identity on a single matrix passed through
        ``non_sequences`` and is iterated a fixed number of steps. The test
        compiles both the scan output and the gradient of its sum with
        respect to the matrix, then compares them with NumPy references:
        the input stacked once per step along a new leading axis, and an
        array shaped like the input filled with the step count.
        """
        # Symbolic graph: identity step over a lone non-sequence matrix.
        step_count = 2
        mat = tensor.matrix()
        repeated, _ = theano.scan(
            lambda x: x,
            non_sequences=[mat],
            n_steps=step_count,
        )
        # Differentiate the summed scan output with respect to the matrix.
        grad_wrt_mat = tensor.grad(repeated.sum(), mat)
        compiled = theano.function([mat], [repeated, grad_wrt_mat])

        # Concrete input and the reference values computed with NumPy.
        values = numpy.array([[1, 2], [3, 4]], dtype=theano.config.floatX)
        want_forward = numpy.repeat(values[None], step_count, axis=0)
        want_grad = numpy.ones(values.shape, dtype="int8") * step_count

        # Run the compiled function and check both results.
        got_forward, got_grad = compiled(values)
        utt.assert_allclose(got_forward, want_forward)
        utt.assert_allclose(got_grad, want_grad)

    # simple rnn, one input, one state, weights for each; input/state
    # are vectors, weights are scalars
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号