test_basic.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_csr_correct_output_faster_than_scipy(self):
        """Check csr x dense ``theano.dot``: correct, and not much slower than scipy.

        For each (M, N, K, nnz) size, build a random M x N csr matrix (from
        ``random_lil``, presumably holding ``nnz`` nonzeros) and a random
        N x K dense matrix. Multiply them once through a compiled Theano
        function and once directly with scipy.

        The two results must agree (``utt.assert_allclose``). Unless the
        configured mode is DebugMode, and only when ``theano.config.cxx`` is
        truthy, Theano's wall time must also not exceed
        ``overhead_rtol * scipy_time + overhead_tol``.
        """
        # Local import keeps this block self-contained. default_timer is the
        # highest-resolution wall clock available; time.time() can tick far
        # more coarsely than the 2 ms tolerance checked below on some systems.
        from timeit import default_timer

        # Both the csr operand and the dense operand are float32.
        sparse_dtype = 'float32'
        dense_dtype = 'float32'

        a = SparseType('csr', dtype=sparse_dtype)()
        b = tensor.matrix(dtype=dense_dtype)
        d = theano.dot(a, b)
        f = theano.function([a, b], d)

        # Timing is only compared outside DebugMode and when
        # theano.config.cxx is set. NOTE(review): presumably because DebugMode
        # and non-C++ builds are slow by design -- confirm.
        # The condition does not change inside the loop, so evaluate it once.
        check_speed = (theano.config.mode not in ["DebugMode", "DEBUG_MODE"] and
                       theano.config.cxx)
        overhead_tol = 0.002  # seconds of fixed slack
        overhead_rtol = 1.1  # Theano may take up to this many times scipy's time

        for M, N, K, nnz in [(4, 3, 2, 3),
                             (40, 30, 20, 3),
                             (40, 30, 20, 30),
                             (400, 3000, 200, 6000),
                             ]:
            spmat = sp.csr_matrix(random_lil((M, N), sparse_dtype, nnz))
            mat = numpy.asarray(numpy.random.randn(N, K), dense_dtype)

            t0 = default_timer()
            theano_result = f(spmat, mat)
            t1 = default_timer()
            # .dot is explicit matrix multiplication; for csr_matrix it is the
            # same as `*`, but it stays correct for sparse arrays where `*`
            # means elementwise multiplication.
            scipy_result = spmat.dot(mat)
            t2 = default_timer()

            theano_time = t1 - t0
            scipy_time = t2 - t1

            utt.assert_allclose(scipy_result, theano_result)

            if check_speed:
                limit = overhead_rtol * scipy_time + overhead_tol
                self.assertFalse(
                    theano_time > limit,
                    (theano_time, limit, scipy_time, overhead_rtol,
                     overhead_tol))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号