test_algorithms.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:thunder-factorization 作者: thunder-project 项目源码 文件源码
def test_svd(eng):
    x = make_low_rank_matrix(n_samples=10, n_features=5, random_state=0)
    x = fromarray(x, engine=eng)

    from sklearn.utils.extmath import randomized_svd
    u1, s1, v1 = randomized_svd(x.toarray(), n_components=2,  random_state=0)

    u2, s2, v2 = SVD(k=2, method='direct').fit(x)
    assert allclose_sign(u1, u2)
    assert allclose(s1, s2)
    assert allclose_sign(v1.T, v2.T)

    u2, s2, v2 = SVD(k=2, method='em', max_iter=100, seed=0).fit(x)
    tol = 1e-1
    assert allclose_sign(u1, u2, atol=tol)
    assert allclose(s1, s2, atol=tol)
    assert allclose_sign(v1.T, v2.T, atol=tol)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号