test_blas.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_dot22():
    """Check that ``T.dot`` on two matrices is optimized and computes correctly.

    For every pair of input dtypes drawn from float32/float64/complex64/
    complex128, compile ``T.dot(a, b)`` with ``mode_blas_opt`` and check that:

    * when both dtypes are equal, the optimized graph contains ``_dot22``;
    * when they differ, the graph still contains a ``T.Dot`` instance;
    * the compiled function matches ``numpy.dot`` (shape and values) for a
      regular shape pair and for pairs with zero-length dimensions.
    """
    dtypes = ['float32', 'float64', 'complex64', 'complex128']
    for dtype1 in dtypes:
        a = T.matrix(dtype=dtype1)
        for dtype2 in dtypes:
            b = T.matrix(dtype=dtype2)
            f = theano.function([a, b], T.dot(a, b), mode=mode_blas_opt)
            topo = f.maker.fgraph.toposort()
            if dtype1 == dtype2:
                assert _dot22 in [x.op for x in topo], (dtype1, dtype2)
            else:
                # Mixed dtypes are not expected to be rewritten to _dot22,
                # but some Dot op must still be present in the graph.
                assert any(isinstance(x.op, T.Dot) for x in topo), (
                    dtype1, dtype2)
            # Re-seed per dtype pair so each pair sees the same inputs.
            rng = numpy.random.RandomState(unittest_tools.fetch_seed())

            # Named check_shapes rather than cmp to avoid shadowing the
            # Python 2 builtin. It is defined and called within the same
            # iteration, so closing over dtype1/dtype2/f/rng is safe.
            def check_shapes(a_shp, b_shp):
                av = rng.uniform(size=a_shp).astype(dtype1)
                bv = rng.uniform(size=b_shp).astype(dtype2)
                out = numpy.asarray(f(av, bv))
                expected = numpy.dot(av, bv)
                # Verify the result, not just that the call did not raise;
                # this also covers the empty-dimension cases below.
                assert out.shape == expected.shape, (
                    dtype1, dtype2, a_shp, b_shp, out.shape, expected.shape)
                # rtol is loose enough for 32-bit accumulation of a few terms.
                assert numpy.allclose(out, expected, rtol=1e-4, atol=1e-6), (
                    dtype1, dtype2, a_shp, b_shp)

            check_shapes((3, 4), (4, 5))
            check_shapes((0, 4), (4, 5))
            check_shapes((3, 0), (0, 5))
            check_shapes((3, 4), (4, 0))
            check_shapes((0, 4), (4, 0))
            check_shapes((0, 0), (0, 0))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号