test_basic.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_scalar_axes(self):
        """Check ``tensordot`` when ``axes`` is given as a single integer.

        Two cases are exercised. Each one compiles ``tensordot`` with
        ``inplace_func``, compares the result against ``numpy.tensordot`` on
        the same random inputs, and then runs ``utt.verify_grad`` on
        ``self.TensorDot(axes)``:

        * matrix x matrix with ``axes=1``: the left input is an ``fmatrix``
          fed float32 data, the right input is a ``dmatrix``, deliberately
          left at float64 to test mixing float32 and float64;
        * tensor3 x tensor3 with ``axes=2``.
        """
        def _compare_and_grad(lhs, rhs, n_axes, lhs_val, rhs_val):
            # Build and compile the symbolic product for these inputs.
            product = tensordot(lhs, rhs, n_axes)
            compiled = inplace_func([lhs, rhs], product)
            # Reference value first, then the compiled function, matching
            # the original evaluation order.
            expected = numpy.tensordot(lhs_val, rhs_val, n_axes)
            actual = compiled(lhs_val, rhs_val)
            self.assertTrue(numpy.allclose(expected, actual))
            # Numeric gradient check of the TensorDot op for this axes value.
            utt.verify_grad(self.TensorDot(n_axes), [lhs_val, rhs_val])

        # Matrix-matrix: float32 data on the left, dmatrix (float64) on the
        # right, to exercise the mixed-precision path.
        _compare_and_grad(fmatrix(), dmatrix(), 1,
                          rand(4, 5).astype('float32'), rand(5, 3))

        # Tensor-tensor: two tensor3 inputs contracted with axes=2.
        _compare_and_grad(tensor3(), tensor3(), 2,
                          rand(3, 4, 5), rand(4, 5, 3))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号