def plot_learning_curve(estimator, X, y, train_sizes=np.linspace(.1, 1.0, 5),
                        cv=None, n_jobs=1, ax=None):
    """
    Plot the learning curve for `estimator`.

    Runs ``learning_curve`` on ``(X, y)`` and draws on ``ax``:
    the mean training score and the mean cross-validation score at each
    training-set size, each surrounded by a band of +/- one standard
    deviation across folds.

    Parameters
    ----------
    estimator : sklearn.Estimator
        Estimator forwarded to ``learning_curve``.
    X : array-like
        Training features.
    y : array-like
        Training targets.
    train_sizes : array-like
        Training-set sizes forwarded to ``learning_curve``; a list of floats
        between 0 and 1 is interpreted as fractions of the training set.
    cv : int, cross-validation generator or iterable, optional
        Cross-validation strategy forwarded to ``learning_curve``.
    n_jobs : int
        Number of parallel jobs forwarded to ``learning_curve``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When None, a new figure and axes are created.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes the learning curve was drawn on.
    """
    # Adapted from
    # http://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html
    if ax is None:
        _, ax = plt.subplots()
    ax.set_xlabel("Training examples")
    ax.set_ylabel("Score")

    # `train_sizes` is rebound to the absolute sizes learning_curve actually
    # used; the score arrays are reduced over axis 1 (the CV folds, per the
    # scikit-learn documentation of learning_curve's return shape).
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes
    )
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    # Draw on `ax` explicitly. The pyplot calls (plt.plot, plt.legend, ...)
    # target the *current* axes, which need not be the `ax` the caller
    # passed in. `grid(True)` enables the grid instead of toggling it.
    ax.grid(True)
    ax.fill_between(train_sizes, train_scores_mean - train_scores_std,
                    train_scores_mean + train_scores_std, alpha=0.1,
                    color="r")
    ax.fill_between(train_sizes, test_scores_mean - test_scores_std,
                    test_scores_mean + test_scores_std, alpha=0.1, color="g")
    ax.plot(train_sizes, train_scores_mean, 'o-', color="r",
            label="Training score")
    ax.plot(train_sizes, test_scores_mean, 'o-', color="g",
            label="Cross-validation score")
    ax.legend(loc="best")
    return ax
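# 评论列表 ("Comment list") / 文章目录 ("Table of contents"):
# stray navigation text from the web page this code was copied from.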