def test_shapes(self):
dims = [
((1, 1), (2, 1, 1)), # broadcast first argument
((2, 1, 1), (1, 1)), # broadcast second argument
((2, 1, 1), (2, 1, 1)), # matrix stack sizes match
]
for dt, (dm1, dm2) in itertools.product(self.types, dims):
a = np.ones(dm1, dtype=dt)
b = np.ones(dm2, dtype=dt)
res = self.matmul(a, b)
assert_(res.shape == (2, 1, 1))
# vector vector returns scalars.
for dt in self.types:
a = np.ones((2,), dtype=dt)
b = np.ones((2,), dtype=dt)
c = self.matmul(a, b)
assert_(np.array(c).shape == ())
# Comment list
# Article table of contents