def test_learning_curve():
    """Check ``learning_curve`` output for a mock estimator on a tiny dataset.

    Generates a 30-sample, single-feature, two-class problem, runs
    ``learning_curve`` with ``cv=3`` over ten training-set fractions, and
    verifies the shapes of the score arrays, the returned ``train_sizes``
    and the per-size mean train/test scores.

    Raises:
        RuntimeError: if any warning was recorded while ``learning_curve``
            was running.
    """
    features, labels = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    # NOTE(review): presumably the mock's scores change linearly with the
    # number of training samples (up to 20) -- confirm against its definition.
    mock_estimator = MockImprovingEstimator(20)

    # Record warnings instead of printing them so the test can fail on them.
    with warnings.catch_warnings(record=True) as caught:
        sizes, scores_on_train, scores_on_test = learning_curve(
            mock_estimator,
            features,
            labels,
            cv=3,
            train_sizes=np.linspace(0.1, 1.0, 10),
        )
    if caught:
        raise RuntimeError("Unexpected warning: %r" % caught[0].message)

    # One row per requested training size, one column per CV split.
    expected_shape = (10, 3)
    for score_matrix in (scores_on_train, scores_on_test):
        assert_equal(score_matrix.shape, expected_shape)

    assert_array_equal(sizes, np.linspace(2, 20, 10))

    # Compare the scores averaged over the CV splits (axis=1).
    mean_train = scores_on_train.mean(axis=1)
    mean_test = scores_on_test.mean(axis=1)
    assert_array_almost_equal(mean_train, np.linspace(1.9, 1.0, 10))
    assert_array_almost_equal(mean_test, np.linspace(0.1, 1.0, 10))
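# Page-navigation residue from the source web page, not part of the test:
# "评论列表" (comment list), "文章目录" (article table of contents).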