test_solver.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:picard 作者: pierreablin 项目源码 文件源码
def test_picard():
    N, T = 2, 10000
    rng = np.random.RandomState(42)
    S = rng.laplace(size=(N, T))
    A = rng.randn(N, N)
    X = np.dot(A, S)
    for precon in [1, 2]:
        Y, W = picard(X, precon=precon, verbose=True)
        # Get the final gradient norm
        G = np.inner(np.tanh(Y / 2.), Y) / float(T) - np.eye(N)
        assert_allclose(G, np.zeros((N, N)), atol=1e-7)
        assert_equal(Y.shape, X.shape)
        assert_equal(W.shape, A.shape)
        WA = np.dot(W, A)
        WA = get_perm(WA)[1]  # Permute and scale
        assert_allclose(WA, np.eye(N), rtol=1e-2, atol=1e-2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号