def test_scalar_axes(self):
    """Check ``tensordot`` when ``axes`` is given as a plain integer.

    Two cases are run, in order:

    * a ``fmatrix`` contracted with a ``dmatrix`` over one axis, so the
      graph mixes float32 and float64 inputs;
    * two ``tensor3`` variables contracted over two axes.

    For each case the compiled graph is compared against
    ``numpy.tensordot`` with ``numpy.allclose``, and the gradient of
    ``self.TensorDot(axes)`` is checked with ``utt.verify_grad``.
    """
    def check(x, y, n_axes, x_val, y_val):
        # Build and compile the symbolic contraction.
        contraction = tensordot(x, y, n_axes)
        compiled = inplace_func([x, y], contraction)
        # The compiled result must match NumPy's reference result.
        reference = numpy.tensordot(x_val, y_val, n_axes)
        self.assertTrue(numpy.allclose(reference, compiled(x_val, y_val)))
        # Numerical gradient check of the op with the same integer axes.
        utt.verify_grad(self.TensorDot(n_axes), [x_val, y_val])

    # Matrix x matrix over one axis. The first operand is float32 and the
    # second is left as a float64 (dmatrix) variable so the dtypes are mixed.
    x_mat32 = fmatrix()
    y_mat64 = dmatrix()
    x_mat32_val = rand(4, 5).astype('float32')
    y_mat64_val = rand(5, 3)
    check(x_mat32, y_mat64, 1, x_mat32_val, y_mat64_val)

    # 3-d tensor x 3-d tensor over two axes: the last two axes of the first
    # operand (4, 5) are summed against the first two of the second (4, 5).
    x_t3 = tensor3()
    y_t3 = tensor3()
    x_t3_val = rand(3, 4, 5)
    y_t3_val = rand(4, 5, 3)
    check(x_t3, y_t3, 2, x_t3_val, y_t3_val)
评论列表
文章目录