def test_scan_extra_inputs_hessian(self):
    """Compute a Hessian row by row with ``theano.scan`` and check its values.

    The cost is ``y = fc1 * dot(x * x, dot(A, x))``. Its gradient ``gy``
    is differentiated again inside a scan step, one index at a time, after
    being multiplied by ``fc2``.

    ``fc2`` is a shared variable that the step function captures from the
    enclosing scope. It is deliberately not listed in ``non_sequences``, so
    the scan inner graph has to see it as an extra, implicit input.
    """
    vec = theano.tensor.vector('x')
    mat = theano.tensor.matrix('A')
    fc1 = theano.shared(0.5, name='fc1')
    fc2 = theano.shared(0.9, name='fc2')

    cost = fc1 * theano.dot(vec * vec, theano.dot(mat, vec))
    cost.name = 'y'
    grad_cost = theano.tensor.grad(cost, vec)
    grad_cost.name = 'gy'

    def hessian_row(idx, g, inp):
        # Differentiate one gradient component (scaled by the captured
        # shared ``fc2``) w.r.t. the input: this yields row ``idx`` of the
        # Hessian, scaled by fc1 * fc2.
        return theano.tensor.grad(g[idx] * fc2, inp)

    # The second return value of scan (the updates) is not needed here.
    hessian, _ = theano.scan(
        hessian_row,
        sequences=theano.tensor.arange(grad_cost.shape[0]),
        non_sequences=[grad_cost, vec])

    fn = theano.function([vec, mat], hessian, allow_input_downcast=True)

    dtype = theano.config.floatX
    x_val = numpy.array([1., 1.], dtype=dtype)
    A_val = numpy.array([[1., 1.], [1., 0.]], dtype=dtype)
    # Analytically, H_kl = 2*delta_kl*(Ax)_k + 2*x_k*A_kl + 2*x_l*A_lk.
    # At these inputs that is [[8, 4], [4, 2]]; scaling by
    # fc1 * fc2 = 0.45 gives the expected matrix below.
    expected = numpy.array([[3.6, 1.8], [1.8, 0.9]], dtype=dtype)

    utt.assert_allclose(fn(x_val, A_val), expected)
评论列表
文章目录