classifier.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:false-friends 作者: pln-fing-udelar 项目源码 文件源码
def classify_with_cross_validation(X, y, clf, n_folds=5):
    """Evaluate ``clf`` with stratified k-fold cross validation.

    For every fold the classifier is refit on the training split and used to
    predict the held-out split. The fold's confusion matrix is flattened and
    unpacked into ``calculate_measures``, whose per-measure results are then
    summarised across folds.

    :param X: feature matrix indexable by an integer index array
        (presumably a NumPy array or SciPy sparse matrix -- verify with callers).
    :param y: label array, indexable the same way as ``X``.
    :param clf: estimator exposing ``fit(X, y)`` and ``predict(X)``; it is
        refit in place on every fold, so its final state is the last fold's fit.
    :param n_folds: number of stratified folds (default 5).
    :return: ``collections.OrderedDict`` mapping each measure name returned by
        ``calculate_measures`` to ``(mean, delta)``, where ``delta`` is the
        half-width of an approximate 95% confidence interval of the mean across
        folds (normal approximation, z = 1.96).
    """
    # Fix the label set once so every fold yields a confusion matrix of the same
    # shape, even if a fold's y_test and y_predicted happen to contain a single
    # class (otherwise calculate_measures would receive the wrong arity).
    # np.unique sorts, matching sklearn's default label ordering.
    labels = np.unique(y)

    # measure name -> list of per-fold values, kept in first-seen order.
    fold_measures = collections.OrderedDict()

    logging.info("classifying and predicting with cross validation")
    skf = cross_validation.StratifiedKFold(y, n_folds=n_folds)
    for train_indices, test_indices in skf:
        X_train = X[train_indices]
        X_test = X[test_indices]
        y_train = y[train_indices]
        y_test = y[test_indices]

        clf.fit(X_train, y_train)
        y_predicted = clf.predict(X_test)

        confusion_matrix = metrics.confusion_matrix(
            y_test, y_predicted, labels=labels).flatten()
        for measure_name, measure_value in calculate_measures(*confusion_matrix).items():
            fold_measures.setdefault(measure_name, []).append(measure_value)

    cv_measures = collections.OrderedDict()
    for measure_name, measure_values in fold_measures.items():
        n_values = len(measure_values)
        mean = np.mean(measure_values)
        # Folds are a sample, so use the sample standard deviation (ddof=1);
        # the population estimate (ddof=0) understates the spread.
        std = np.std(measure_values, ddof=1) if n_values > 1 else 0.0
        # 1.96 * standard error of the mean -> ~95% confidence half-width.
        delta = std * 1.96 / math.sqrt(n_values)
        cv_measures[measure_name] = (mean, delta)

    return cv_measures


# noinspection PyPep8Naming
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号