test_blas.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
              max_graphlen=0, expected_nb_gemm=1):
    """Check that compiling ``o`` from ``i`` in FAST_RUN mode uses gemm_inplace.

    The graph is compiled twice: once with ``mode='FAST_RUN'`` and once with
    the Python linker and no optimizer.  The optimized graph is inspected,
    then both functions are run on identical random inputs and their first
    outputs are compared.

    Parameters
    ----------
    i : sequence
        Input variables.  For the optimized function each one is wrapped in
        ``In(..., mutable=True, allow_downcast=True)``.
    o : sequence
        Output variables passed to ``inplace_func``.
    ishapes : sequence of tuples
        One shape per input; each is expanded into ``rng.randn(*shape)`` to
        build the test values.  The default list is only iterated, never
        mutated.
    max_graphlen : int
        When non-zero, the largest number of nodes allowed in the optimized
        graph's toposort.  ``0`` disables the check.
    expected_nb_gemm : int
        Number of ``gemm_inplace`` nodes expected in the optimized graph.

    Raises
    ------
    Failure
        If a ``T.Dot`` or ``_dot22`` node remains in the optimized graph, or
        if the two compiled functions disagree by more than the tolerance.
        The optimized graph is printed before the exception propagates.
    AssertionError
        If the number of ``gemm_inplace`` nodes differs from
        ``expected_nb_gemm`` or the optimized graph is longer than
        ``max_graphlen``.
    Exception
        If ``gemm_inplace`` appears in the unoptimized graph.
    """
    def _random_inputs():
        # Re-seed on every call so both compiled functions receive exactly
        # the same values.
        rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
        return [numpy.asarray(rng.randn(*sh), config.floatX)
                for sh in ishapes]

    try:
        f = inplace_func(
            [In(ii, mutable=True, allow_downcast=True) for ii in i],
            o,
            mode='FAST_RUN',
            on_unused_input='ignore')
        nb_gemm = 0
        for node in f.maker.fgraph.apply_nodes:
            if isinstance(node.op, T.Dot):
                raise Failure('dot not changed to gemm_inplace in graph')
            if node.op == _dot22:
                raise Failure('_dot22 not changed to gemm_inplace in graph')
            if node.op == gemm_inplace:
                nb_gemm += 1
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if nb_gemm != expected_nb_gemm:
            raise AssertionError((nb_gemm, expected_nb_gemm))

        # Reference function: Python linker, no optimizer.
        g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
                         allow_input_downcast=True, on_unused_input='ignore')
        for node in g.maker.fgraph.apply_nodes:
            if node.op == gemm_inplace:
                raise Exception('gemm_inplace in original graph')

        graphlen = len(f.maker.fgraph.toposort())
        # Fail when the optimized graph EXCEEDS the bound; the original
        # comparison was inverted relative to its own error message.
        if max_graphlen and graphlen > max_graphlen:
            # theano.printing.debugprint(f)
            raise AssertionError(
                'graphlen=%i>%i' % (graphlen, max_graphlen))

        r0 = f(*_random_inputs())
        r1 = g(*_random_inputs())
        max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
        # Looser tolerance when floatX is single precision.
        eps = 1.0e-6 if config.floatX == 'float32' else 1.0e-8
        if max_abs_err > eps:
            raise Failure('GEMM is computing the wrong output. max_abs_err =',
                          max_abs_err)
    except Failure:
        # Dump the optimized graph to help diagnose the failure, then let
        # the exception propagate.
        for node in f.maker.fgraph.toposort():
            print('GRAPH', node)
        raise
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号