kgrid.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:jamespy_py3 作者: jskDr 项目源码 文件源码
def gs_param( model, X, y, param_grid, n_splits=5, shuffle=True, n_jobs=-1, graph=False):
    """
    gs = gs_param( model, X, y, param_grid, n_splits=5, shuffle=True, n_jobs=-1)

    Inputs
    ======
    model = svm.SVC(), or linear_model.LinearRegression(), for example
    param = {"C": np.logspace(-2,2,5)}
    """
    #print(xM.shape, yVc.shape)
    kf5_c = model_selection.KFold( n_splits=n_splits, shuffle=shuffle)
    gs = model_selection.GridSearchCV( model, param_grid, cv=kf5_c, n_jobs=n_jobs)
    gs.fit( X, y)

    if graph:
        plt.plot( gs.cv_results_["mean_train_score"], label='E[Train]')
        plt.plot( gs.cv_results_["mean_test_score"], label='E[Test]')
        plt.legend(loc=0)
        plt.grid()

    return gs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号