def test_against_numpy_tensordot(self):
    """Compare the streamed tensordot with numpy.tensordot for two 2D arrays.

    For each contraction depth ``axes`` in (0, 1, 2), the last value
    produced by ``itensordot`` over a stream of two random 8x8 arrays must
    match ``np.tensordot`` applied to the same pair with the same ``axes``,
    both in shape and (within floating-point tolerance) in value.
    """
    stream = tuple(np.random.random((8, 8)) for _ in range(2))
    for axis in (0, 1, 2):
        with self.subTest('axis = {}'.format(axis)):
            # The loop variable must be forwarded to both implementations;
            # otherwise every subtest compares only the default (axes=2).
            from_numpy = np.tensordot(*stream, axes=axis)
            from_stream = last(itensordot(stream, axes=axis))
            self.assertSequenceEqual(from_numpy.shape, from_stream.shape)
            self.assertTrue(np.allclose(from_numpy, from_stream))
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")