learn.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:extract 作者: dblalock 项目源码 文件源码
def gridSearchPipeline(pipeline, paramsGrid, Xtrain, Ytrain, **cvParams):
    """Run a grid search over ``paramsGrid`` for ``pipeline`` and return it.

    Args:
        pipeline: Estimator handed to ``GridSearchCV``; it is also printed
            for logging purposes.
        paramsGrid: Parameter grid handed to ``GridSearchCV``.
        Xtrain: Training inputs passed to ``fit``.
        Ytrain: Training labels passed to ``fit`` and, when no ``cv`` is
            supplied, used to build the default stratified folds.
        **cvParams: Extra keyword arguments forwarded to ``GridSearchCV``.
            If ``cv`` is absent, a ``StratifiedKFold(Ytrain, n_folds=5)``
            splitter is supplied.

    Returns:
        The ``GridSearchCV`` object after ``fit(Xtrain, Ytrain)`` was called.
    """
    print("Grid Searching pipeline:")
    print(pipeline)

    searchOptions = dict(cvParams)
    if 'cv' not in searchOptions:
        # Default to 5-fold stratified cross-validation so that the class
        # balance stays consistent across training and testing splits.
        # NOTE(review): the (labels, n_folds=...) call form looks like the
        # legacy sklearn.cross_validation API -- confirm against the imports.
        searchOptions['cv'] = StratifiedKFold(Ytrain, n_folds=5)

    search = GridSearchCV(pipeline, paramsGrid, **searchOptions)
    search.fit(Xtrain, Ytrain)
    return search
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号