def test_perform(self):
    """Compare the symbolic ``kron`` against ``scipy.linalg.kron``.

    Every pairing of a left operand of rank 1-4 with a right operand of
    rank 1-4 is compiled and evaluated on random ``floatX`` data, except
    the pairing where both operands are 1-D. Each result is checked with
    ``utt.assert_allclose`` against the scipy reference.

    Raises SkipTest when scipy could not be imported.
    """
    if not imported_scipy:
        raise SkipTest('kron tests need the scipy package to be installed')

    left_shapes = [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]
    right_shapes = [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]

    for left_shape in left_shapes:
        left_ndim = len(left_shape)
        left_sym = tensor.tensor(dtype='floatX',
                                 broadcastable=(False,) * left_ndim)
        # The left value is drawn once per left shape, before any right
        # value, so the RNG consumption order stays fixed.
        left_val = numpy.asarray(
            self.rng.rand(*left_shape)).astype(config.floatX)

        for right_shape in right_shapes:
            right_ndim = len(right_shape)
            total_ndim = left_ndim + right_ndim
            # The 1-D x 1-D pairing is not exercised by this test.
            if total_ndim == 2:
                continue

            right_sym = tensor.tensor(dtype='floatX',
                                      broadcastable=(False,) * right_ndim)
            kron_fn = function([left_sym, right_sym],
                               kron(left_sym, right_sym))
            right_val = self.rng.rand(*right_shape).astype(config.floatX)
            computed = kron_fn(left_val, right_val)

            if total_ndim == 3:
                # Per the original note, newer scipy versions need at least
                # 4 combined dimensions here. A leading axis is therefore
                # added to the left operand, and the reference result is
                # flattened before comparison.
                # NOTE(review): confirm against the scipy version in use.
                expected = scipy.linalg.kron(
                    left_val[numpy.newaxis, :], right_val).flatten()
            else:
                expected = scipy.linalg.kron(left_val, right_val)

            utt.assert_allclose(computed, expected)
评论列表
文章目录