def cv(method, xM, yV, alpha, n_folds=5, n_jobs=-1, grid_std=None, graph=True, shuffle=True):
    """Produce a cross-validated prediction for every input sample.

    A linear model is fit with K-fold cross-validation so that each sample's
    prediction comes from a model trained on the other folds.

    Parameters
    ----------
    method : str
        Name of an estimator class in ``linear_model`` that accepts an
        ``alpha`` keyword, e.g. ``'Ridge'`` or ``'Lasso'``.
    xM : array-like
        Feature matrix; must expose ``.shape`` (it is printed).
    yV : array-like
        Target values; must expose ``.shape`` (it is printed).
    alpha : float
        Regularization strength passed to the estimator constructor.
    n_folds : int, default 5
        Number of K-fold splits.
    n_jobs : int, default -1
        Forwarded to ``model_selection.cross_val_predict``.
    grid_std : optional
        Forwarded to ``jutil.cv_show`` when ``graph`` is true.
    graph : bool, default True
        If true, print a header line and call ``jutil.cv_show`` on the result.
    shuffle : bool, default True
        Whether ``KFold`` shuffles the samples before splitting.

    Returns
    -------
    yV_pred
        Out-of-fold predictions as returned by
        ``model_selection.cross_val_predict``, one per sample.
    """
    print(xM.shape, yV.shape)
    clf = getattr(linear_model, method)(alpha=alpha)
    # Respect the caller's `shuffle` flag instead of always shuffling.
    # Passing the KFold object (not a one-shot .split() generator) lets
    # cross_val_predict generate the splits itself.
    kf = model_selection.KFold(n_splits=n_folds, shuffle=shuffle)
    yV_pred = model_selection.cross_val_predict(clf, xM, yV, cv=kf, n_jobs=n_jobs)
    if graph:
        print('The prediction output using cross-validation is given by:')
        jutil.cv_show(yV, yV_pred, grid_std=grid_std)
    return yV_pred
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")