def classify(train=None, test=None, data=None, res_dir="res/", disp=True, outfilename=None):
    """Fit the enabled classifier(s) on a training set and predict a test set.

    Features and ground truths are taken from ``data`` when it is given;
    otherwise they are loaded from the ``train`` and ``test`` files.

    Args:
        train: Training file; resolved with ``utils.abs_path_file`` and then
            parsed with ``read_file``. Only used when ``data`` is None.
        test: Test file, handled the same way as ``train``.
        data: Optional mapping providing the keys ``"train_features"``,
            ``"train_groundtruths"``, ``"test_features"`` and
            ``"test_groundtruths"``. Takes precedence over ``train``/``test``.
        res_dir: Directory handed to ``utils.create_dir``; it is not used
            any further inside this function.
        disp: Not used by this function.
        outfilename: Not used by this function.

    Returns:
        Whatever ``predict`` returns for the test features, for the last
        classifier iterated (only "RandomForest" is currently enabled).
    """
    utils.print_success("Comparison of differents classifiers")

    if data is None:
        # No in-memory dataset supplied: load both sets from disk.
        train = utils.abs_path_file(train)
        test = utils.abs_path_file(test)
        x_train, y_train = read_file(train)
        x_test, y_test = read_file(test)
    else:
        x_train = data["train_features"]
        y_train = data["train_groundtruths"]
        x_test = data["test_features"]
        y_test = data["test_groundtruths"]

    # NOTE(review): presumably a falsy result means the directory was not
    # newly created (e.g. it already existed) -- confirm in utils.create_dir.
    dir_created = utils.create_dir(res_dir)
    if not dir_created:
        res_dir = utils.abs_path_dir(res_dir)

    # Only the random forest is enabled; the commented entries are the
    # alternatives that can be switched back on for a comparison run.
    classifiers = {
        "RandomForest": RandomForestClassifier(n_jobs=-1),
        # "RandomForest": RandomForestClassifier(n_estimators=5),
        # "KNeighbors": KNeighborsClassifier(3),
        # "GaussianProcess": GaussianProcessClassifier(1.0 * RBF(1.0), warm_start=True),
        # "DecisionTree": DecisionTreeClassifier(max_depth=5),
        # "MLP": MLPClassifier(),
        # "AdaBoost": AdaBoostClassifier(),
        # "GaussianNB": GaussianNB(),
        # "QDA": QuadraticDiscriminantAnalysis(),
        # "SVM": SVC(kernel="linear", C=0.025),
        # "GradientBoosting": GradientBoostingClassifier(),
        # "ExtraTrees": ExtraTreesClassifier(),
        # "LogisticRegression": LogisticRegression(),
        # "LinearDiscriminantAnalysis": LinearDiscriminantAnalysis(),
    }

    for name, model in classifiers.items():
        utils.print_success(name)
        utils.print_info("\tFit")
        model.fit(x_train, y_train)
        utils.print_info("\tPredict")
        predictions = model.predict(x_test)
    return predictions
评论列表
文章目录