test_slinalg.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_solve_correctness(self):
        """Compare ``self.op`` with SciPy's solvers on three systems.

        Three Theano functions are compiled from ``self.op``: one taking a
        general matrix, one taking the lower Cholesky factor as its input and
        one taking the upper Cholesky factor as its input.  Each is evaluated
        on seeded random data and checked with ``numpy.allclose`` against
        ``scipy.linalg.solve`` / ``scipy.linalg.solve_triangular``.

        Raises:
            SkipTest: if SciPy could not be imported.
        """
        if not imported_scipy:
            raise SkipTest("Scipy needed for the Cholesky and Solve ops.")
        rng = numpy.random.RandomState(utt.fetch_seed())

        # Symbolic left-hand-side matrix and right-hand side.
        A = theano.tensor.matrix()
        b = theano.tensor.matrix()

        # General system: the op is applied straight to the inputs.
        gen_solve_func = theano.function([A, b], self.op(A, b))

        # Lower factor: L is passed as a function input, so the compiled
        # function is fed the factor itself rather than A.
        L = Cholesky(lower=True)(A)
        lower_solve_func = theano.function([L, b], self.op(L, b))

        # Upper factor, handled the same way as the lower one.
        U = Cholesky(lower=False)(A)
        upper_solve_func = theano.function([U, b], self.op(U, b))

        # The right-hand side is drawn before the matrix; keep this order so
        # the seeded random stream yields the same values.
        b_val = numpy.asarray(rng.rand(5, 1), dtype=config.floatX)
        A_val = numpy.asarray(rng.rand(5, 5), dtype=config.floatX)
        # M^T M is symmetric, and positive definite when M has full rank,
        # which is what the Cholesky factorizations below rely on.
        A_val = numpy.dot(A_val.T, A_val)

        # 1 - general case
        expected_gen = scipy.linalg.solve(A_val, b_val)
        actual_gen = gen_solve_func(A_val, b_val)
        assert numpy.allclose(expected_gen, actual_gen)

        # 2 - lower triangular case
        L_val = scipy.linalg.cholesky(A_val, lower=True)
        expected_lower = scipy.linalg.solve_triangular(L_val, b_val,
                                                       lower=True)
        actual_lower = lower_solve_func(L_val, b_val)
        assert numpy.allclose(expected_lower, actual_lower)

        # 3 - upper triangular case
        U_val = scipy.linalg.cholesky(A_val, lower=False)
        expected_upper = scipy.linalg.solve_triangular(U_val, b_val,
                                                       lower=False)
        actual_upper = upper_solve_func(U_val, b_val)
        assert numpy.allclose(expected_upper, actual_upper)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号