def test_threshold_without_refitting():
    """Check that changing ``threshold`` after ``fit`` affects ``transform``.

    The selector is fitted once. Its ``threshold`` attribute is then raised
    and ``transform`` is called again with no second ``fit``; the second call
    must keep strictly fewer columns than the first.
    """
    estimator = SGDClassifier(
        alpha=0.1, n_iter=10, shuffle=True, random_state=0
    )
    model = SelectFromModel(estimator, threshold=0.1)
    model.fit(data, y)
    n_kept_before = model.transform(data).shape[1]

    # Raise the threshold on the already-fitted selector; no refit happens.
    model.threshold = 1.0
    n_kept_after = model.transform(data).shape[1]

    assert_greater(n_kept_before, n_kept_after)
评论列表
文章目录