def plot_learning_curve(estimators, X, y, cv=10, scoring=None, n_jobs=1):
    """Plot one learning curve per estimator, side by side in a single figure.

    For every estimator, ``learning_curve`` is evaluated on ten training-set
    fractions evenly spaced from 10% to 100% of the data, and the resulting
    train/test scores are handed to ``_plot_curve`` for drawing.

    Parameters
    ----------
    estimators : sequence
        Estimators to evaluate; one subplot is created for each. Must
        support ``len()``.
    X, y
        Training data and targets, forwarded unchanged to ``learning_curve``.
    cv : default 10
        Cross-validation setting forwarded to ``learning_curve``.
    scoring : default None
        Scoring specification. It is forwarded both to ``learning_curve``
        (so the scores are computed with it) and to ``_plot_curve``.
    n_jobs : int, default 1
        Parallelism setting forwarded to ``learning_curve``.

    Returns
    -------
    The matplotlib figure containing all subplots.
    """
    n_panels = len(estimators)
    # Each panel gets 6.4 x 4.8, matplotlib's default figure size.
    figsize = (6.4 * n_panels, 4.8)
    # squeeze=False always yields a 2-D array of axes, so the single-estimator
    # case needs no special handling; row 0 holds every panel.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_panels, figsize=figsize, squeeze=False
    )

    # Loop-invariant inputs: same size grid and x-axis label for every panel.
    size_grid = np.linspace(start=0.1, stop=1.0, num=10)
    xlabel = 'Number of training samples'

    for ax, estimator in zip(axes[0], estimators):
        train_sizes, train_scores, test_scores = learning_curve(
            estimator=estimator,
            X=X,
            y=y,
            train_sizes=size_grid,
            cv=cv,
            # Previously hard-coded to None, which silently ignored the
            # caller's metric while the plot was still labelled with it.
            scoring=scoring,
            n_jobs=n_jobs,
            verbose=1
        )
        _plot_curve(
            axes=ax,
            train_sizes=train_sizes,
            train_scores=train_scores,
            test_scores=test_scores,
            xlabel=xlabel,
            scoring=scoring
        )
        ax.set_title(pipeline_name(estimator))

    return fig
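# 评论列表 (comment list — web-page navigation residue, not code)
# 文章目录 (table of contents — web-page navigation residue, not code)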