test_blas.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_factorised_scalar(self):
        """Check that scalar-scaled dot/shared updates fuse into one Gemm.

        Each case compiles a function of ``a`` and ``b`` whose only effect
        is an update of the shared matrix ``s`` built from ``T.dot(a, b)``,
        ``s`` itself and scalar constants. It then asserts that the
        optimized graph is exactly one ``gemm_inplace`` node. That means the
        optimizer folded the scalar constants (and, in the last case, the
        subtraction) into Gemm's scalar coefficients.
        """
        a = T.matrix()
        b = T.matrix()
        s = theano.shared(numpy.zeros((5, 5)).astype(config.floatX))

        lr1 = T.constant(0.01).astype(config.floatX)
        lr2 = T.constant(2).astype(config.floatX)
        l2_reg = T.constant(0.0001).astype(config.floatX)

        def assert_single_gemm(new_s):
            # Compile the update ``s <- new_s`` (no outputs) and require that
            # the whole optimized graph collapsed into a single in-place Gemm.
            # The node list is attached to each assert so a failure shows
            # what the optimizer actually produced.
            nodes = theano.function(
                [a, b], updates=[(s, new_s)],
                mode=mode_not_fast_compile).maker.fgraph.toposort()
            assert len(nodes) == 1, nodes
            assert nodes[0].op == gemm_inplace, nodes

        # Constant merge with gemm. The two constants multiplying s merge:
        # 0.01 * dot(a, b) + (0.0001 * 2) * s, expected roughly
        #   Gemm{inplace}(<matrix>, 0.01, <matrix>, <matrix>, 2e-06)
        assert_single_gemm(lr1 * T.dot(a, b) + l2_reg * lr2 * s)

        # Factored scalar with merge. Distributing lr1 gives
        # 0.01 * dot(a, b) - (0.01 * 0.0001) * s, expected roughly
        #   Gemm{inplace}(<matrix>, 0.01, <matrix>, <matrix>, -1e-06)
        assert_single_gemm(lr1 * (T.dot(a, b) - l2_reg * s))

        # Factored scalar with merge and negation:
        # s - 0.01 * (0.0002 * s + dot(a, b))
        #   = 0.999998 * s - 0.01 * dot(a, b), expected roughly
        #   Gemm{inplace}(<matrix>, -0.01, <matrix>, <matrix>, 0.999998)
        assert_single_gemm(s - lr1 * (s * .0002 + T.dot(a, b)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号