def validate(data, labels, *, estimator=None, n_splits=10, random_state=None,
             average='binary'):
    """Estimate classifier performance over repeated stratified random splits.

    This uses ``StratifiedShuffleSplit``, which draws ``n_splits``
    independent stratified train/test splits. Test sets of different splits
    may overlap, so this is repeated stratified hold-out validation, not
    k-fold cross-validation. The test fraction is the splitter's default
    (0.1 in current scikit-learn -- confirm for your version).

    For each split the estimator is fitted on the training part and scored on
    the test part. The mean of every metric is printed and returned.

    Args:
        data: Feature matrix. It must support integer-array indexing
            (e.g. a NumPy array), because rows are selected with
            ``data[index_array]``.
        labels: Target vector, indexed the same way as ``data``.
        estimator: Object with ``fit``/``predict``. Defaults to the
            module-level ``clf`` for backward compatibility. Note that it is
            refitted in place, so it is left fitted on the last split.
        n_splits: Number of random splits to draw (default 10).
        random_state: Seed passed to ``StratifiedShuffleSplit``. Pass an int
            for reproducible results. ``None`` (the default) draws different
            splits on every call.
        average: Averaging mode for precision, recall and F1. The default
            ``'binary'`` matches the original behaviour. Use e.g.
            ``'macro'`` or ``'weighted'`` for multiclass labels.

    Returns:
        dict mapping 'Accuracy', 'Precision', 'Recall' and 'F1-measure' to
        the mean score across all splits.
    """
    # Resolve the global lazily so callers passing `estimator` don't need `clf`.
    model = clf if estimator is None else estimator
    # Insertion order fixes the print order: Accuracy, Precision, Recall, F1.
    scores = {'Accuracy': [], 'Precision': [], 'Recall': [], 'F1-measure': []}
    splitter = StratifiedShuffleSplit(n_splits=n_splits, random_state=random_state)
    for train_index, test_index in splitter.split(data, labels):
        x_train, x_test = data[train_index], data[test_index]
        y_train, y_test = labels[train_index], labels[test_index]
        model.fit(x_train, y_train)
        y_pred = model.predict(x_test)
        scores['Accuracy'].append(accuracy_score(y_test, y_pred))
        scores['Precision'].append(precision_score(y_test, y_pred, average=average))
        scores['Recall'].append(recall_score(y_test, y_pred, average=average))
        scores['F1-measure'].append(f1_score(y_test, y_pred, average=average))
    means = {name: np.mean(values) for name, values in scores.items()}
    for name, value in means.items():
        print(name, value)
    return means
# Scraped web-page navigation residue ("Comment list", "Table of contents"); not part of the code.