CCIT.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:CCIT 作者: rajatsen91 项目源码 文件源码
def cross_validate(classifier, n_folds = 5):
    '''Custom cross-validation module I always use '''
    train_X = classifier['train_X']
    train_y = classifier['train_y']
    model = classifier['model']
    score = 0.0

    skf = KFold(n_splits = n_folds)
    for train_index, test_index in skf.split(train_X):
        X_train, X_test = train_X[train_index], train_X[test_index]
        y_train, y_test = train_y[train_index], train_y[test_index]
        clf = model.fit(X_train,y_train)
        pred = clf.predict_proba(X_test)[:,1]
        #print 'cross', roc_auc_score(y_test,pred)
        score = score + roc_auc_score(y_test,pred)

    return score/n_folds
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号