decision_tree_manual_classifier.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:SLIC_cityscapes 作者: wpqmanu 项目源码 文件源码
def decision_tree_manual_classifier(all_feature_data, dot_path="cityscapes.dot"):
    """Build a hand-rolled decision tree, then fit and export a scikit-learn tree.

    The hand-rolled tree is built with ``buildtree`` from the training rows
    (each row with its label appended), and every feature row is run through
    ``classify``. A ``DecisionTreeClassifier`` is then fit on the same data,
    its training accuracy is printed, and the fitted tree is exported to a
    Graphviz ``.dot`` file.

    Args:
        all_feature_data: Indexable pair where ``all_feature_data[0]`` holds the
            feature rows and ``all_feature_data[1]`` the matching labels. Rows
            are extended with ``row + [label]``, so they are expected to be
            Python lists (a numpy row would be added element-wise instead).
        dot_path: Output path for the Graphviz export. Defaults to the
            previously hard-coded ``"cityscapes.dot"``.

    Returns:
        The fitted ``DecisionTreeClassifier``.

    Raises:
        ValueError: If the number of feature rows and labels differ.
    """
    features = all_feature_data[0]
    labels = all_feature_data[1]
    if len(features) != len(labels):
        raise ValueError(
            "feature rows (%d) and labels (%d) differ in length"
            % (len(features), len(labels))
        )

    input_data = np.asarray(features)
    label = np.asarray(labels)

    # The manual tree builder takes each sample as one row whose last column
    # is the label.
    data_for_manual_tree = [row + [lab] for row, lab in zip(features, labels)]

    # Named ``manual_tree`` rather than ``tree`` so it does not shadow the
    # module-level ``tree`` used for ``export_graphviz`` below.
    manual_tree = buildtree(data_for_manual_tree)

    # NOTE(review): the manual tree's predictions are computed but discarded,
    # as in the original code. TODO: evaluate them or drop this loop.
    for row in features:
        classify(row, manual_tree)

    clf = DecisionTreeClassifier()
    fit_clf = clf.fit(input_data, label)

    result = fit_clf.predict(input_data)
    accuracy = float(np.sum(result == label)) / len(label)
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print("Training accuracy is " + str(accuracy))

    # Assumes the module-level name ``tree`` is ``sklearn.tree`` — confirm
    # against the file's imports.
    with open(dot_path, "w") as dot_file:
        tree.export_graphviz(clf, out_file=dot_file)

    return fit_clf
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号