classify.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:ISM2017 作者: ybayle 项目源码 文件源码
def train_test(train, test, res_dir="res/", disp=True, outfilename=None):
    """Fit each enabled classifier on ``train`` and report scores on ``test``.

    For every classifier in the local ``classifiers`` table, the model is
    fitted on the training data and used to predict the test data. The
    weighted precision, recall and F1 scores are then printed.

    Args:
        train: Mapping whose keys are used as ground-truth labels and whose
            values are used as the matching feature vectors.
        test: Mapping with the same key/value layout as ``train``.
        res_dir: Result directory, passed to ``utils.create_dir``.
        disp: Currently unused; kept for backward compatibility.
        outfilename: Currently unused; kept for backward compatibility.

    Returns:
        None. Scores are only printed. If ``train`` or ``test`` is ``None``,
        an error is reported and the function returns without training.
    """
    utils.print_success("Comparison of differents classifiers")
    if train is None or test is None:
        utils.print_error("No valid data provided.")
        # Without this return, execution used to fall through and crash
        # with UnboundLocalError on the never-assigned feature lists.
        return

    # Iterating a mapping yields its keys. Look features up through that
    # same key sequence so labels and features stay index-aligned.
    train_groundtruths = list(train)
    train_features = [train[label] for label in train_groundtruths]
    test_groundtruths = list(test)
    test_features = [test[label] for label in test_groundtruths]

    utils.create_dir(res_dir)

    # Uncomment entries to include more classifiers in the comparison.
    classifiers = {
        # "RandomForest": RandomForestClassifier(n_estimators=5),
        "KNeighbors": KNeighborsClassifier(1),
        # "GaussianProcess": GaussianProcessClassifier(1.0 * RBF(1.0), warm_start=True),
        # "DecisionTree": DecisionTreeClassifier(max_depth=5),
        # "MLP": MLPClassifier(),
        # "AdaBoost": AdaBoostClassifier(),
        # "GaussianNB": GaussianNB(),
        # "QDA": QuadraticDiscriminantAnalysis(),
        # "SVM": SVC(kernel="linear", C=0.025),
        # "GradientBoosting": GradientBoostingClassifier(),
        # "ExtraTrees": ExtraTreesClassifier(),
        # "LogisticRegression": LogisticRegression(),
        # "LinearDiscriminantAnalysis": LinearDiscriminantAnalysis(),
    }

    # (printed label, scoring function); every score uses average="weighted".
    metrics = (
        ("Precision weighted\t", precision_score),
        ("Recall weighted\t", recall_score),
        ("F1 weighted\t", f1_score),
    )

    for name, clf in classifiers.items():
        utils.print_success(name)
        utils.print_info("\tFit")
        clf.fit(train_features, train_groundtruths)
        utils.print_info("\tPredict")
        predictions = clf.predict(test_features)

        for label, metric in metrics:
            score = metric(test_groundtruths, predictions, average="weighted")
            print(label + str(score))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号