def train_test(train, test, res_dir="res/", disp=True, outfilename=None):
    """Fit each enabled classifier on ``train`` and print its scores on ``test``.

    For every classifier enabled in the ``classifiers`` mapping below, the
    model is fitted on the training data, used to predict the test data, and
    the weighted precision, recall and F1 scores are printed to stdout.

    Args:
        train: Mapping iterated as ``key -> features``. Each key is used as the
            ground-truth label for its value. (Presumably a dict keyed by
            label/identifier -- TODO confirm against callers.)
        test: Mapping with the same layout as ``train``.
        res_dir: Results directory, created via ``utils.create_dir`` once the
            input data has been validated.
        disp: Currently unused; kept for interface compatibility.
        outfilename: Currently unused; kept for interface compatibility.

    Returns:
        None. If ``train`` or ``test`` is None, an error is printed and the
        function returns without fitting anything.
    """
    utils.print_success("Comparison of differents classifiers")
    if train is None or test is None:
        # Without this early return the feature/label lists below would never
        # be bound and the first ``clf.fit`` call would raise UnboundLocalError.
        utils.print_error("No valid data provided.")
        return

    # Iterating a mapping yields its keys: keys are the labels, values the
    # features. Listing the keys once keeps labels and features aligned.
    train_groundtruths = list(train)
    train_features = [train[key] for key in train_groundtruths]
    test_groundtruths = list(test)
    test_features = [test[key] for key in test_groundtruths]

    utils.create_dir(res_dir)

    # Uncomment entries to add them to the comparison.
    classifiers = {
        # "RandomForest": RandomForestClassifier(n_estimators=5),
        "KNeighbors": KNeighborsClassifier(1),
        # "GaussianProcess": GaussianProcessClassifier(1.0 * RBF(1.0), warm_start=True),
        # "DecisionTree": DecisionTreeClassifier(max_depth=5),
        # "MLP": MLPClassifier(),
        # "AdaBoost": AdaBoostClassifier(),
        # "GaussianNB": GaussianNB(),
        # "QDA": QuadraticDiscriminantAnalysis(),
        # "SVM": SVC(kernel="linear", C=0.025),
        # "GradientBoosting": GradientBoostingClassifier(),
        # "ExtraTrees": ExtraTreesClassifier(),
        # "LogisticRegression": LogisticRegression(),
        # "LinearDiscriminantAnalysis": LinearDiscriminantAnalysis(),
    }

    for name, clf in classifiers.items():
        utils.print_success(name)
        utils.print_info("\tFit")
        clf.fit(train_features, train_groundtruths)
        utils.print_info("\tPredict")
        predictions = clf.predict(test_features)
        # Pass average=None instead of 'weighted' to get per-class scores.
        print("Precision weighted\t" + str(precision_score(test_groundtruths, predictions, average='weighted')))
        print("Recall weighted\t" + str(recall_score(test_groundtruths, predictions, average='weighted')))
        print("F1 weighted\t" + str(f1_score(test_groundtruths, predictions, average='weighted')))
# Web-page navigation residue from the scraped source ("Comment list", "Table of contents").