test_basic.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_weird_valid_axes(self):
        # Test matrix-matrix
        amat = matrix()
        bmat = matrix()
        for axes in [0,
                     (1, 0),
                     [1, 0],
                     (1, (0, )),
                     ((1, ), 0),
                     ([1], [0]),
                     ([], [])]:
            c = tensordot(amat, bmat, axes)
            f3 = inplace_func([amat, bmat], c)
            aval = rand(4, 7)
            bval = rand(7, 9)
            self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
                                           f3(aval, bval)))
            utt.verify_grad(self.TensorDot(axes), [aval, bval])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号