utils.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:time_series_modeling 作者: rheineke 项目源码 文件源码
def plot_validation_curve(estimators, X, y, cv=10, **kwargs):
    figsize = (6.4 * len(estimators), 4.8)
    fig, axes = plt.subplots(nrows=1, ncols=len(estimators), figsize=figsize)
    param_range = [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]

    if len(estimators) == 1:
        axes = [axes]

    for ax, estimator in zip(axes, estimators):
        train_scores, test_scores = validation_curve(
            estimator=estimator,
            X=X,
            y=y,
            param_name='clf__C',
            param_range=param_range,
            cv=cv,
            **kwargs)
        xlabel = 'Parameter C'
        _plot_curve(ax, param_range, train_scores, test_scores, xlabel, 'log')
        ax.set_title(pipeline_name(estimator))

    # fig.tight_layout(pad=1.08, h_pad=None, w_pad=None, rect=None)

    return fig
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号