utils.py source code

python

Project: time_series_modeling  Author: rheineke

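import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import learning_curve

# _plot_curve and pipeline_name are helpers defined elsewhere in utils.py
# (not shown on this page).


def plot_learning_curve(estimators, X, y, cv=10, scoring=None, n_jobs=1):
    """Plot one learning curve per estimator, side by side, and return the figure."""
    # One 6.4 x 4.8 inch panel (matplotlib's default figure size) per estimator.
    figsize = (6.4 * len(estimators), 4.8)
    fig, axes = plt.subplots(nrows=1, ncols=len(estimators), figsize=figsize)

    # plt.subplots returns a single Axes rather than an array when ncols == 1.
    if len(estimators) == 1:
        axes = [axes]

    for ax, estimator in zip(axes, estimators):
        # Cross-validated scores at 10 training-set sizes, from 10% to 100%
        # of the training data available in each CV split.
        train_sizes, train_scores, test_scores = learning_curve(
            estimator=estimator,
            X=X,
            y=y,
            train_sizes=np.linspace(start=0.1, stop=1.0, num=10),
            cv=cv,
            scoring=scoring,  # None falls back to the estimator's own score method
            n_jobs=n_jobs,
            verbose=1
        )
        xlabel = 'Number of training samples'
        _plot_curve(
            axes=ax,
            train_sizes=train_sizes,
            train_scores=train_scores,
            test_scores=test_scores,
            xlabel=xlabel,
            scoring=scoring
        )
        ax.set_title(pipeline_name(estimator))
    return fig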
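The function relies on two helpers that are not shown on this page, `_plot_curve` and `pipeline_name`. As a minimal sketch, assuming `_plot_curve` draws the mean and one standard deviation of the fold scores (a common way to plot learning curves) and `pipeline_name` builds a title from a scikit-learn `Pipeline`'s steps, their non-plotting parts could look like the code below. Both `summarize_scores` and this `pipeline_name` are hypothetical stand-ins, not the project's actual code.

```python
import numpy as np


def summarize_scores(train_scores, test_scores):
    """Hypothetical helper: reduce learning_curve score matrices to per-size stats.

    Each input has shape (n_train_sizes, n_cv_folds); the mean and standard
    deviation are taken across folds (axis=1), giving one value per training size.
    """
    train_scores = np.asarray(train_scores, dtype=float)
    test_scores = np.asarray(test_scores, dtype=float)
    return {
        "train_mean": train_scores.mean(axis=1),
        "train_std": train_scores.std(axis=1),
        "test_mean": test_scores.mean(axis=1),
        "test_std": test_scores.std(axis=1),
    }


def pipeline_name(estimator):
    """Hypothetical stand-in for the project's pipeline_name helper.

    If the estimator exposes a `steps` list of (name, estimator) pairs, as
    scikit-learn's Pipeline does, join the class names of the steps;
    otherwise return the estimator's own class name.
    """
    steps = getattr(estimator, "steps", None)
    if steps:
        return " -> ".join(type(step).__name__ for _, step in steps)
    return type(estimator).__name__
```

With these statistics, a plotting helper would typically call `ax.plot(train_sizes, stats["test_mean"])` and shade `stats["test_mean"] ± stats["test_std"]` with `ax.fill_between`, and likewise for the training scores. Duck-typing on `steps` keeps `pipeline_name` usable for both bare estimators and pipelines.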