test_scan.py source code

Project: Theano-Deep-learning    Author: GeekLiB
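import numpy
import theano
from theano.tests import unittest_tools as utt


# Test method from test_scan.py; `self` is the enclosing test case and is unused.
def test_scan_extra_inputs_hessian(self):
    # Symbolic inputs: a vector x and a matrix A.
    x = theano.tensor.vector('x')
    A = theano.tensor.matrix('A')
    # Shared scalars that enter the graph as extra (implicit) inputs.
    fc1 = theano.shared(0.5, name='fc1')
    fc2 = theano.shared(0.9, name='fc2')
    # Scalar cost y = fc1 * sum_i x_i**2 * (A x)_i and its gradient w.r.t. x.
    y = fc1 * theano.dot(x * x, theano.dot(A, x))
    y.name = 'y'
    gy = theano.tensor.grad(y, x)
    gy.name = 'gy'
    # Build the Hessian row by row: scan over i = 0..len(gy)-1 and take the
    # gradient of fc2 * gy[i] w.r.t. x. fc2 is captured from the enclosing
    # scope instead of being passed through non_sequences.
    hy, updates = theano.scan(
        lambda i, gy, x: theano.tensor.grad(gy[i] * fc2, x),
        sequences=theano.tensor.arange(gy.shape[0]),
        non_sequences=[gy, x])

    f = theano.function([x, A], hy, allow_input_downcast=True)
    vx = numpy.array([1., 1.], dtype=theano.config.floatX)
    vA = numpy.array([[1., 1.], [1., 0.]], dtype=theano.config.floatX)
    # Expected result: fc2 times the Hessian of y evaluated at (vx, vA).
    vR = numpy.array([[3.6, 1.8], [1.8, 0.9]], dtype=theano.config.floatX)
    out = f(vx, vA)

    utt.assert_allclose(out, vR)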
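The expected matrix vR can be checked without Theano. Differentiating y = fc1 * sum_i x_i^2 (A x)_i twice gives H_kj = 2 * fc1 * (delta_kj (A x)_k + x_k A_kj + x_j A_jk), and scan step i returns row i of fc2 * H. With vx = [1, 1] and vA as above, A x = [2, 1], so fc2 * H = 0.9 * [[4, 2], [2, 1]] = [[3.6, 1.8], [1.8, 0.9]]. The sketch below is a pure-Python cross-check under that reading of the test. The helpers matvec, cost, analytic_hessian_rows, shifted and fd_hessian are hypothetical (not part of Theano), and the finite-difference step eps=1e-3 is an arbitrary choice.

```python
# Pure-Python cross-check of vR from test_scan_extra_inputs_hessian.
# Assumption: the scan output is fc2 times the Hessian of
# y(x) = fc1 * sum_i x_i**2 * (A x)_i, one row per scan step.
# All helper names below are hypothetical, chosen for this sketch.

FC1 = 0.5  # value of the shared variable fc1 in the test
FC2 = 0.9  # value of the shared variable fc2 in the test


def matvec(A, x):
    """Return A x for a matrix given as a list of rows."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def cost(x, A):
    """fc1 * dot(x * x, dot(A, x)), the scalar y of the test."""
    Ax = matvec(A, x)
    return FC1 * sum(xi * xi * axi for xi, axi in zip(x, Ax))


def analytic_hessian_rows(x, A):
    """Rows of fc2 * H, H_kj = 2*fc1*(delta_kj*(Ax)_k + x_k*A_kj + x_j*A_jk)."""
    n = len(x)
    Ax = matvec(A, x)
    rows = []
    for k in range(n):  # one row per scan step, like grad(gy[i] * fc2, x)
        row = []
        for j in range(n):
            diag = Ax[k] if k == j else 0.0
            h = 2 * FC1 * (diag + x[k] * A[k][j] + x[j] * A[j][k])
            row.append(FC2 * h)
        rows.append(row)
    return rows


def shifted(x, k, dk, j, dj):
    """Copy of x with dk added to x[k] and dj added to x[j]."""
    xs = list(x)
    xs[k] += dk
    xs[j] += dj
    return xs


def fd_hessian(x, A, eps=1e-3):
    """fc2 * H estimated with central differences of cost().

    For k == j the two shifts land on the same coordinate, which yields the
    usual second difference with step 2*eps.
    """
    n = len(x)
    H = [[0.0] * n for _ in range(n)]
    for k in range(n):
        for j in range(n):
            second = (cost(shifted(x, k, eps, j, eps), A)
                      - cost(shifted(x, k, eps, j, -eps), A)
                      - cost(shifted(x, k, -eps, j, eps), A)
                      + cost(shifted(x, k, -eps, j, -eps), A)) / (4 * eps * eps)
            H[k][j] = FC2 * second
    return H


if __name__ == "__main__":
    vx = [1.0, 1.0]
    vA = [[1.0, 1.0], [1.0, 0.0]]
    print(analytic_hessian_rows(vx, vA))  # → [[3.6, 1.8], [1.8, 0.9]]
    print(fd_hessian(vx, vA))  # agrees with the line above up to rounding
```

Because y is a cubic polynomial in x, these central-difference formulas have no truncation error, so any mismatch between fd_hessian and analytic_hessian_rows comes only from floating-point rounding.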