test_svd.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:dask-ml 作者: dask 项目源码 文件源码
def test_basic(algorithm):
    a = dd.TruncatedSVD(random_state=0, algorithm=algorithm)
    b = sd.TruncatedSVD(random_state=0)
    b.fit(Xdense)
    a.fit(dXdense)

    np.testing.assert_allclose(a.components_, b.components_, atol=1e-3)
    assert_estimator_equal(a, b, exclude=['components_',
                                          'explained_variance_'],
                           atol=1e-3)
    assert a.explained_variance_.shape == b.explained_variance_.shape
    np.testing.assert_allclose(a.explained_variance_,
                               b.explained_variance_,
                               rtol=0.01)

# The rest come straight from scikit-learn, with dask arrays substituted
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号