def test_basic(algorithm):
    """Check that ``dd.TruncatedSVD`` and ``sd.TruncatedSVD`` fit to matching results.

    ``dd.TruncatedSVD`` is built with the given ``algorithm`` and fitted on
    ``dXdense``; ``sd.TruncatedSVD`` is fitted on ``Xdense``. Both use
    ``random_state=0``. The fitted attributes are then compared, with the
    ``sd`` estimator acting as the expected value.

    NOTE(review): ``dd`` / ``sd`` are presumably the dask-based and
    scikit-learn decomposition modules, and ``dXdense`` the dask-array
    counterpart of ``Xdense`` -- confirm against the module's imports and
    fixtures.

    Parameters
    ----------
    algorithm
        Forwarded unchanged as the ``algorithm`` keyword of
        ``dd.TruncatedSVD``.
    """
    actual_svd = dd.TruncatedSVD(random_state=0, algorithm=algorithm)
    expected_svd = sd.TruncatedSVD(random_state=0)
    expected_svd.fit(Xdense)
    actual_svd.fit(dXdense)

    np.testing.assert_allclose(
        actual_svd.components_,
        expected_svd.components_,
        atol=1e-3,
    )

    # These two attributes are left out of the generic comparison because
    # each gets its own dedicated check with its own tolerance in this test.
    checked_separately = ['components_', 'explained_variance_']
    assert_estimator_equal(
        actual_svd,
        expected_svd,
        exclude=checked_separately,
        atol=1e-3,
    )

    actual_variance = actual_svd.explained_variance_
    expected_variance = expected_svd.explained_variance_
    assert actual_variance.shape == expected_variance.shape
    # Explained variance is compared with a relative (1%) tolerance rather
    # than the absolute tolerance used for the other attributes.
    np.testing.assert_allclose(actual_variance, expected_variance, rtol=0.01)
# The rest come straight from scikit-learn, with dask arrays substituted
# (web page residue) Comment list
# (web page residue) Table of contents