test_matrices.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:shenfun 作者: spectralDNS 项目源码 文件源码
def test_lmatvec(b0, b1, quad, format, axis, k0, k1):
    """Test matrix-vector product"""
    global c, c1, d, d1
    b0 = b0(N, quad=quad)
    b1 = b1(N, quad=quad)
    mat = shenfun.spectralbase.inner_product((b0, k0), (b1, k1))
    c = mat.matvec(a, c, format='csr')
    c1 = mat.matvec(a, c1, format=format)
    assert np.allclose(c, c1)

    d.fill(0)
    d1.fill(0)
    d = mat.matvec(b, d, format='csr', axis=axis)
    d1 = mat.matvec(b, d1, format=format, axis=axis)
    assert np.allclose(d, d1)

    # Test multidimensional with axis equals 1D case
    d1.fill(0)
    bc = [np.newaxis,]*3
    bc[axis] = slice(None)
    fj = np.broadcast_to(a[bc], (N,)*3).copy()
    d1 = mat.matvec(fj, d1, format=format, axis=axis)
    cc = [0,]*3
    cc[axis] = slice(None)
    assert np.allclose(c, d1[cc])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号