def gs_param(model, X, y, param_grid, n_splits=5, shuffle=True, n_jobs=-1, graph=False):
    """Run a K-fold grid search over ``param_grid`` and return the fitted search.

    Usage::

        gs = gs_param(model, X, y, param_grid, n_splits=5, shuffle=True, n_jobs=-1)

    Inputs
    ======
    model : estimator
        e.g. ``svm.SVC()`` or ``linear_model.LinearRegression()``.
    X, y :
        Data and targets passed to ``GridSearchCV.fit``.
    param_grid : dict
        Grid of candidate parameter values, e.g. ``{"C": np.logspace(-2, 2, 5)}``.
    n_splits : int
        Number of folds used by ``model_selection.KFold``.
    shuffle : bool
        Passed to ``KFold``. No ``random_state`` is set, so with
        ``shuffle=True`` the fold assignment can differ between calls.
    n_jobs : int
        Passed through to ``GridSearchCV`` (in scikit-learn, -1 means use all
        available processors).
    graph : bool
        If True, plot the mean train and mean test score of every candidate
        in ``cv_results_`` order (the x-axis is the candidate index, not the
        parameter value).

    Returns
    =======
    gs : model_selection.GridSearchCV
        The fitted grid-search object.
    """
    kf5_c = model_selection.KFold(n_splits=n_splits, shuffle=shuffle)
    # GridSearchCV only stores "mean_train_score" in cv_results_ when
    # return_train_score=True (default False). The plot below needs it, so
    # request it exactly when plotting. Otherwise skip the extra scoring cost.
    gs = model_selection.GridSearchCV(
        model, param_grid, cv=kf5_c, n_jobs=n_jobs, return_train_score=graph
    )
    gs.fit(X, y)
    if graph:
        plt.plot(gs.cv_results_["mean_train_score"], label='E[Train]')
        plt.plot(gs.cv_results_["mean_test_score"], label='E[Test]')
        plt.legend(loc=0)
        plt.grid()
    return gs
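# Web-page residue from the article this code was copied from; not code:
# "评论列表" (comment list), "文章目录" (table of contents).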