def test_solve_correctness(self):
    """Compare ``self.op`` with SciPy on one general and two triangular systems.

    A single 5x1 right-hand side ``b`` is solved against:

    1. ``A = M^T M`` for a random 5x5 ``M``, checked against
       ``scipy.linalg.solve``;
    2. the lower Cholesky factor of ``A``, checked against
       ``scipy.linalg.solve_triangular(..., lower=True)``;
    3. the upper Cholesky factor of ``A``, checked against
       ``scipy.linalg.solve_triangular(..., lower=False)``.

    Raises SkipTest when ``imported_scipy`` is false.
    """
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Cholesky and Solve ops.")
    rng = numpy.random.RandomState(utt.fetch_seed())

    # --- symbolic graphs -------------------------------------------------
    a_sym = theano.tensor.matrix()
    b_sym = theano.tensor.matrix()
    solve_general = theano.function([a_sym, b_sym], self.op(a_sym, b_sym))

    # The Cholesky outputs are listed as the compiled functions' inputs;
    # below, these functions are called with precomputed factors.
    lower_sym = Cholesky(lower=True)(a_sym)
    solve_lower = theano.function([lower_sym, b_sym],
                                  self.op(lower_sym, b_sym))

    upper_sym = Cholesky(lower=False)(a_sym)
    solve_upper = theano.function([upper_sym, b_sym],
                                  self.op(upper_sym, b_sym))

    # --- numeric checks --------------------------------------------------
    # b is drawn before A; keep this order so the seeded values are unchanged.
    rhs = numpy.asarray(rng.rand(5, 1), dtype=config.floatX)
    m = numpy.asarray(rng.rand(5, 5), dtype=config.floatX)
    # M^T M is symmetric and, for a full-rank draw, positive definite,
    # which the Cholesky factorisations below rely on.
    spd = numpy.dot(m.T, m)

    # Case 1: general system.
    expected = scipy.linalg.solve(spd, rhs)
    assert numpy.allclose(expected, solve_general(spd, rhs))

    # Cases 2 and 3: lower-, then upper-triangular Cholesky factor.
    for lower, solve_tri in ((True, solve_lower), (False, solve_upper)):
        factor = scipy.linalg.cholesky(spd, lower=lower)
        expected = scipy.linalg.solve_triangular(factor, rhs, lower=lower)
        assert numpy.allclose(expected, solve_tri(factor, rhs))
评论列表
文章目录