def test_weird_valid_axes(self):
    """Exercise tensordot on two symbolic matrices with unusual ``axes`` forms.

    Each entry in ``odd_axes_specs`` is passed unchanged both to the symbolic
    ``tensordot`` and to ``numpy.tensordot``. The specs are: a bare int, int
    pairs as tuple and list, mixed int/sequence pairs, single-element list
    pairs, and a pair of empty lists.

    For every spec, the test does the following:

    * compiles the graph with ``inplace_func``;
    * feeds it inputs from ``rand(4, 7)`` and ``rand(7, 9)``;
    * asserts ``numpy.allclose`` against the NumPy reference;
    * runs ``utt.verify_grad`` on ``self.TensorDot(spec)``.
    """
    # Test matrix-matrix
    x_sym = matrix()
    y_sym = matrix()
    odd_axes_specs = (
        0,
        (1, 0),
        [1, 0],
        (1, (0, )),
        ((1, ), 0),
        ([1], [0]),
        ([], []),
    )
    for spec in odd_axes_specs:
        # A fresh compile is needed per spec: the axes are baked into the graph.
        compiled = inplace_func([x_sym, y_sym], tensordot(x_sym, y_sym, spec))
        x_val = rand(4, 7)
        y_val = rand(7, 9)
        # Compute the NumPy reference before calling the compiled function,
        # matching the original evaluation order (the compiled function is an
        # "inplace" one, so its inputs may not be safe to reuse afterwards).
        expected = numpy.tensordot(x_val, y_val, spec)
        actual = compiled(x_val, y_val)
        self.assertTrue(numpy.allclose(expected, actual))
        utt.verify_grad(self.TensorDot(spec), [x_val, y_val])
# (scraped page residue: "Comment list")
# (scraped page residue: "Table of contents")