test_blas.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_maxpool():
        """TODO: test the gpu version!!! """
        for d0, d1, r_true, r_false in [(4, 4, [[[[5, 7], [13, 15]]]], [[[[5, 7], [13, 15]]]]),
                                        (5, 5, [[[[6, 8], [16, 18], [21, 23]]]],
                                         [[[[6, 8, 9], [16, 18, 19], [21, 23, 24]]]])]:
            for border, ret in [(True, r_true), (False, r_false)]:
                ret = numpy.array(ret)
                a = tcn.blas.Pool((2, 2), border)
                dmatrix4 = tensor.TensorType("float32", (False, False, False, False))
                b = dmatrix4()
                f = pfunc([b], [a(b)], mode=mode_with_gpu)

                bval = numpy.arange(0, d0 * d1).reshape(1, 1, d0, d1)
                r = f(bval)[0]
    #            print bval, bval.shape, border
                # print r, r.shape
                assert (ret == r).all()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号