test_blas.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目: Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def cmp_dot22(self, b_shp, c_shp):
        """Compare Theano ``dot`` of two shared matrices against ``numpy.dot``.

        Four functions are compiled, each storing ``dot(lhs, rhs)`` into the
        shared variable ``a`` through an update.  They differ only in whether
        each operand is the shared variable itself (``b`` / ``c``) or the
        ``.T`` of a shared variable holding the transposed data
        (``b_t`` / ``c_t``).  The operands are then reset to views sliced
        with every sign combination of steps of magnitude 1 and 2, and after
        each call the value of ``a`` must match ``numpy.dot`` of the same
        slices according to ``numpy.allclose``.

        :param b_shp: shape handed to ``self.rand`` for the left operand.
        :param c_shp: shape handed to ``self.rand`` for the right operand.
        """
        out_init = numpy.zeros((0, 0), dtype=self.dtype)
        left = self.rand(*b_shp)
        right = self.rand(*c_shp)

        out_sh = self.shared(out_init, 'a')
        left_sh = self.shared(left, 'b')
        right_sh = self.shared(right, 'c')

        # Same data stored pre-transposed; ``.T`` in the graph undoes it.
        left_t_sh = self.shared(left.T, 'b.T')
        right_t_sh = self.shared(right.T, 'c.T')

        # Non-borrowed copies in each variable's internal storage type; every
        # strided operand below is sliced from a fresh copy of these.
        left_src, right_src, left_t_src, right_t_src = [
            var.get_value(borrow=False, return_internal_type=True)
            for var in (left_sh, right_sh, left_t_sh, right_t_sh)]

        # Compile in the order nn, nt, tn, tt (False = plain, True = via .T).
        funcs = []
        for lhs_is_t, rhs_is_t in itertools_product((False, True), repeat=2):
            lhs = left_t_sh.T if lhs_is_t else left_sh
            rhs = right_t_sh.T if rhs_is_t else right_sh
            funcs.append(theano.function(
                [], [], updates=[(out_sh, tensor.dot(lhs, rhs))],
                mode=self.mode))

        # Exercise every stride pattern (both signs, steps 1 and 2) on every
        # transposition pattern.
        for signs in itertools_product((-1, 1), repeat=4):
            for magnitude in (1, 2):
                l_row, l_col, r_row, r_col = [sgn * magnitude for sgn in signs]
                l_idx = (slice(None, None, l_row), slice(None, None, l_col))
                r_idx = (slice(None, None, r_row), slice(None, None, r_col))

                left_sh.set_value(left_src.copy()[l_idx], borrow=True)
                right_sh.set_value(right_src.copy()[r_idx], borrow=True)
                # Transposed storage takes the two step values swapped.
                left_t_sh.set_value(left_t_src.copy()[l_idx[::-1]],
                                    borrow=True)
                right_t_sh.set_value(right_t_src.copy()[r_idx[::-1]],
                                     borrow=True)

                expected = numpy.dot(left[l_idx], right[r_idx])

                for fn in funcs:
                    fn()
                    assert numpy.allclose(out_sh.get_value(), expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号