def test_lmatvec(b0, b1, quad, format, axis, k0, k1):
    """Test matrix-vector product.

    Builds ``mat = inner_product((b0, k0), (b1, k1))`` and checks three things.

    1. ``mat.matvec`` with *format* matches the ``'csr'`` reference on the
       1D module-level array ``a``.
    2. The same holds for the 3D module-level array ``b`` along *axis*.
    3. A 3D array that equals ``a`` along *axis*, broadcast over the other two
       dimensions, reproduces the 1D result along that axis.

    Parameters
    ----------
    b0, b1 : callables
        Basis constructors, called as ``b0(N, quad=quad)``.
    quad : quadrature identifier forwarded to the basis constructors.
    format : str
        matvec implementation under test, compared against ``'csr'``.
    axis : int
        Axis along which the multidimensional matvec is applied (0, 1 or 2,
        since ``bc``/``cc`` below have length 3).
    k0, k1 : paired with ``b0``/``b1`` in ``inner_product``; presumably
        derivative orders -- TODO confirm against shenfun's API.
    """
    # The output buffers are module-level and are rebound by matvec's return.
    global c, c1, d, d1
    b0 = b0(N, quad=quad)
    b1 = b1(N, quad=quad)
    mat = shenfun.spectralbase.inner_product((b0, k0), (b1, k1))

    # 1D: the requested format must agree with the CSR reference.
    c = mat.matvec(a, c, format='csr')
    c1 = mat.matvec(a, c1, format=format)
    assert np.allclose(c, c1)

    # 3D along `axis`: clear the output buffers first, then compare formats.
    d.fill(0)
    d1.fill(0)
    d = mat.matvec(b, d, format='csr', axis=axis)
    d1 = mat.matvec(b, d1, format=format, axis=axis)
    assert np.allclose(d, d1)

    # Test multidimensional with axis equals 1D case.
    d1.fill(0)
    bc = [np.newaxis] * 3
    bc[axis] = slice(None)
    # NumPy requires a tuple (not a list) for multidimensional indexing;
    # list indices were deprecated and raise since NumPy 1.23.
    fj = np.broadcast_to(a[tuple(bc)], (N,) * 3).copy()
    d1 = mat.matvec(fj, d1, format=format, axis=axis)
    cc = [0] * 3
    cc[axis] = slice(None)
    # Pick one 1D line along `axis`; every such line should equal `c`.
    assert np.allclose(c, d1[tuple(cc)])
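# NOTE(review): two stray web-page navigation labels ("Comment list",
# "Table of contents") were left here by scraping; they are not part of the test.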