test_blas.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_gemm_opt_double_gemm():
    """Check GEMM optimization on the pattern that shows up in the autoencoder.

    The test has two parts:

    1. ``just_gemm`` is asked to check the graph
       ``Z * c + a * dot(X, Y) + b * dot(R, S).T`` with
       ``expected_nb_gemm=2``.
    2. A graph that already contains an explicit ``gemm_inplace`` node is
       compiled in ``FAST_RUN`` mode. The test asserts that no plain
       ``Dot`` or ``_dot22`` node remains in the optimized graph. It then
       compares the optimized result against a reference compiled with the
       Python linker and no optimizer.

    Raises:
        Failure: if a ``Dot``/``_dot22`` node is found in the optimized
            graph, or if the optimized and reference outputs differ by more
            than a floatX-dependent tolerance. Before re-raising, the
            optimized graph is printed node by node.
    """
    X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
    R, S, c = T.matrix(), T.matrix(), T.scalar()
    inputs = [X, Y, Z, a, b, R, S, c]
    # Shapes of the random test values, positionally matching ``inputs``.
    ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()]

    just_gemm(inputs,
              [Z * c + a * T.dot(X, Y) + b * T.dot(R, S).T],
              ishapes=ishapes,
              expected_nb_gemm=2)

    outputs = [(a * T.dot(X, Y)
                + gemm_inplace(Z, b, S.T, R.T,
                               T.constant(1.0).astype(config.floatX)))]

    def _random_inputs():
        # Re-seed on every call so the optimized and reference functions
        # receive identical input values.
        rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
        return [numpy.asarray(rng.randn(*sh), config.floatX)
                for sh in ishapes]

    try:
        f = inplace_func([In(ii, mutable=True) for ii in inputs], outputs,
                         mode='FAST_RUN', on_unused_input='ignore')
        # The optimized graph must not contain any plain Dot or _dot22 node.
        for node in f.maker.fgraph.apply_nodes:
            if isinstance(node.op, T.Dot):
                raise Failure('dot in graph')
            if node.op == _dot22:
                raise Failure('_dot22 in graph')

        # Reference function: Python linker and no optimizer, so ``outputs``
        # is evaluated as written (including the explicit gemm_inplace node).
        g = inplace_func(inputs, outputs,
                         mode=compile.Mode(linker='py', optimizer=None),
                         on_unused_input='ignore')

        r0 = f(*_random_inputs())
        r1 = g(*_random_inputs())
        max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
        # float32 results carry more rounding error, so use a looser bound.
        eps = 1.0e-6 if config.floatX == 'float32' else 1.0e-8
        if max_abs_err > eps:
            raise Failure(
                'GEMM is computing the wrong output. max_abs_err =',
                max_abs_err)
    except Failure:
        # Every Failure above is raised after ``f`` is bound, so the
        # optimized graph can always be dumped here for diagnosis.
        for node in f.maker.fgraph.toposort():
            print('GRAPH', node)
        raise
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号