def plot_validation_curve(estimators, X, y, cv=10, *, param_name='clf__C',
                          param_range=None, **kwargs):
    """Plot one validation curve per estimator, side by side in a single row.

    For each estimator, ``validation_curve`` is called over ``param_range``
    for the hyper-parameter ``param_name`` and the resulting train/test
    scores are handed to ``_plot_curve`` to be drawn on that estimator's
    subplot. Each subplot is titled with ``pipeline_name(estimator)``.

    NOTE(review): ``validation_curve`` is presumably
    ``sklearn.model_selection.validation_curve``; ``_plot_curve`` and
    ``pipeline_name`` are helpers defined outside this block -- confirm.

    Args:
        estimators: Sized sequence of estimators. With the default
            ``param_name`` each one must expose a ``clf__C`` parameter
            (presumably a pipeline with a step named ``clf`` -- confirm).
        X: Training data, forwarded unchanged to ``validation_curve``.
        y: Targets, forwarded unchanged to ``validation_curve``.
        cv: Cross-validation setting forwarded to ``validation_curve``.
        param_name: Name of the hyper-parameter to vary. Defaults to
            ``'clf__C'``, the value that used to be hard-coded.
        param_range: Values to evaluate for ``param_name``. Defaults to
            ``[0.001, 0.01, 0.1, 1.0, 10.0, 100.0]``.
        **kwargs: Extra keyword arguments forwarded to ``validation_curve``.

    Returns:
        The figure holding one subplot per estimator.

    Raises:
        ValueError: If ``estimators`` is empty.
    """
    n_estimators = len(estimators)
    if n_estimators == 0:
        # Fail with a clear message instead of letting plt.subplots choke on
        # a zero-column grid.
        raise ValueError('estimators must contain at least one estimator')
    if param_range is None:
        param_range = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]

    # Each subplot gets Matplotlib's default 6.4 x 4.8 inch footprint.
    # squeeze=False always yields a 2-D (1, n) array of axes, so a single
    # estimator needs no special handling.
    fig, axes = plt.subplots(
        nrows=1,
        ncols=n_estimators,
        figsize=(6.4 * n_estimators, 4.8),
        squeeze=False,
    )

    # Label the axis with the bare parameter name: 'clf__C' -> 'Parameter C'.
    xlabel = 'Parameter ' + param_name.rsplit('__', 1)[-1]

    for ax, estimator in zip(axes[0], estimators):
        train_scores, test_scores = validation_curve(
            estimator=estimator,
            X=X,
            y=y,
            param_name=param_name,
            param_range=param_range,
            cv=cv,
            **kwargs)
        # 'log' is passed through to _plot_curve; presumably the x-axis scale,
        # which suits the logarithmically spaced default grid -- confirm.
        _plot_curve(ax, param_range, train_scores, test_scores, xlabel, 'log')
        ax.set_title(pipeline_name(estimator))
    return fig
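# Comment list          (web-page navigation residue, not part of the code)
# Table of contents     (web-page navigation residue, not part of the code)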