def gridSearchPipeline(pipeline, paramsGrid, Xtrain, Ytrain, nFolds=5,
                       **cvParams):
    """Grid-search the hyper-parameters of ``pipeline`` and return the fit.

    Prints the pipeline being searched, builds a ``GridSearchCV`` over
    ``paramsGrid`` and fits it on ``Xtrain`` / ``Ytrain``.

    Args:
        pipeline: estimator (e.g. an sklearn ``Pipeline``) to tune.
        paramsGrid: parameter grid handed to ``GridSearchCV``.
        Xtrain: training features.
        Ytrain: training targets. When no ``cv`` is supplied they are used
            for stratification, so they should be discrete class labels.
        nFolds: number of stratified folds used when the caller does not
            pass ``cv`` (default 5, the previously hard-coded value).
        **cvParams: extra keyword arguments forwarded unchanged to
            ``GridSearchCV`` (e.g. ``cv``, ``scoring``, ``n_jobs``).

    Returns:
        The ``GridSearchCV`` instance after ``fit(Xtrain, Ytrain)``.
    """
    print("Grid Searching pipeline:")
    print(pipeline)
    # Use stratified k-fold cross-validation by default so every fold keeps
    # a class balance consistent with the full training set.
    if 'cv' not in cvParams:
        try:
            # scikit-learn >= 0.18 (sklearn.model_selection): the splitter is
            # data-independent; GridSearchCV passes y to split() during fit.
            cvParams['cv'] = StratifiedKFold(n_splits=nFolds)
        except TypeError:
            # Legacy sklearn.cross_validation API (removed in 0.20), where
            # the labels are bound when the splitter is constructed.
            cvParams['cv'] = StratifiedKFold(Ytrain, n_folds=nFolds)
    search = GridSearchCV(pipeline, paramsGrid, **cvParams)
    search.fit(Xtrain, Ytrain)
    return search
评论列表
文章目录