def classify_with_cross_validation(X, y, clf, n_folds=5):
    """Evaluate ``clf`` with stratified k-fold cross validation.

    For every fold the classifier is re-fitted on the training split and
    used to predict the test split. The flattened confusion matrix of the
    fold is passed positionally to ``calculate_measures`` and each measure
    it returns is collected, per measure name, across the folds. Finally
    every measure is summarised over the folds.

    Args:
        X: Samples. Must support indexing with the index collections
            yielded by the splitter (presumably a NumPy array -- confirm
            against callers).
        y: Labels, indexed the same way as ``X``; also handed to the
            splitter to build the stratified folds.
        clf: Classifier exposing ``fit(X, y)`` and ``predict(X)``. It is
            fitted in place, once per fold.
        n_folds: Number of folds requested from the splitter; also used as
            the sample size in the confidence-interval formula.

    Returns:
        A ``collections.defaultdict`` mapping each measure name to a
        ``(mean, delta)`` tuple, where ``mean`` is the average of the
        per-fold values and ``delta`` is ``1.96 * std / sqrt(n_folds)``,
        i.e. the half-width of a normal-approximation 95% confidence
        interval (``np.std`` is called with its default arguments).
    """
    cv_measures = collections.defaultdict(list)  # FIXME: use collections.OrderedDict too

    logging.info("classifying and predicting with cross validation")
    # NOTE(review): this call signature looks like scikit-learn's legacy
    # ``cross_validation`` module (labels passed to the constructor, the
    # object itself iterated) -- confirm before upgrading scikit-learn.
    skf = cross_validation.StratifiedKFold(y, n_folds=n_folds)
    for train_indices, test_indices in skf:
        X_train = X[train_indices]
        X_test = X[test_indices]
        y_train = y[train_indices]
        y_test = y[test_indices]

        clf.fit(X_train, y_train)
        y_predicted = clf.predict(X_test)

        # Flatten so the matrix cells can be unpacked as positional arguments.
        confusion_matrix = metrics.confusion_matrix(y_test, y_predicted).flatten()
        for measure_name, measure_value in calculate_measures(*confusion_matrix).items():
            cv_measures[measure_name].append(measure_value)

    # Iterate over a snapshot: the values of ``cv_measures`` are replaced
    # (per-fold list -> summary tuple) inside the loop.
    for measure_name, measure_values in list(cv_measures.items()):
        mean = np.mean(measure_values)
        delta = np.std(measure_values) * 1.96 / math.sqrt(n_folds)  # 95% of confidence
        cv_measures[measure_name] = (mean, delta)

    return cv_measures
# noinspection PyPep8Naming
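# Comment list  (web-page navigation residue, not part of the program)
# Table of contents  (web-page navigation residue, not part of the program)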