def test_transform_linear_model():
    """Check feature selection via the deprecated ``transform`` of linear models.

    For every combination of classifier, threshold and input container (dense
    ``np.array`` or sparse ``sp.csr_matrix``), the classifier is fitted with an
    L1 penalty on ``data``/``y`` (module-level fixtures defined elsewhere in
    this file).

    Calling ``transform(X, threshold)`` on that fitted classifier must:
      * emit a ``DeprecationWarning``;
      * never return more columns than the input has;
      * return strictly fewer columns, except for ``SGDClassifier``.

    The classifier is then refitted with an L2 penalty on the reduced matrix.
    Its accuracy on that same training data must exceed 0.7.
    """
    classifiers = (
        LogisticRegression(C=0.1),
        LinearSVC(C=0.01, dual=False),
        SGDClassifier(alpha=0.001, n_iter=50, shuffle=True, random_state=0),
    )
    thresholds = (None, ".09*mean", "1e-5 * median")
    converters = (np.array, sp.csr_matrix)

    # The same estimator instance is reused across iterations. Each pass
    # re-sets the penalty and refits it, so no state from the previous pass
    # is relied on.
    for model in classifiers:
        for threshold in thresholds:
            for to_input in converters:
                X_full = to_input(data)

                # Fit with an L1 penalty so that transform has small
                # coefficients to threshold away.
                model.set_params(penalty="l1")
                model.fit(X_full, y)
                X_reduced = assert_warns(DeprecationWarning, model.transform,
                                         X_full, threshold)

                n_kept = X_reduced.shape[1]
                n_total = X_full.shape[1]
                if not isinstance(model, SGDClassifier):
                    assert_less(n_kept, n_total)
                else:
                    # SGDClassifier is only required not to add features.
                    # Presumably its L1 fit may keep every column at these
                    # thresholds -- TODO confirm.
                    assert_true(n_kept <= n_total)

                # The reduced features must still support a usable L2 model.
                model.set_params(penalty="l2")
                model.fit(X_reduced, y)
                predictions = model.predict(X_reduced)
                assert_greater(np.mean(predictions == y), 0.7)
评论列表
文章目录