decision_tree_classifier.py source code

python

Project: SLIC_cityscapes  Author: wpqmanu
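import numpy as np
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier


def decision_tree_classifier(all_feature_data):
    # all_feature_data[0]: feature rows, all_feature_data[1]: labels.
    # Fits a default decision tree, prints its accuracy on the training
    # data, writes the tree to cityscapes.dot, and returns the fitted model.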
    input_data=np.asarray(all_feature_data[0])
    label=np.asarray(all_feature_data[1])

    data=input_data[:,:]
    # data=sklearn.preprocessing.normalize(data,axis=0)

    # clf = DecisionTreeClassifier(criterion="gini",
                                 # splitter="best",
                                 # max_features=None,
                                 # max_depth=5,
                                 # min_samples_leaf=1,
                                 # min_samples_split=2,
                                 # class_weight=None
                                 # )
    clf = DecisionTreeClassifier()
    fit_clf=clf.fit(data,label)

    result=fit_clf.predict(data)
    accuracy=float(np.sum(result==label))/len(label)
    print("Training accuracy is " + str(accuracy))
    with open("cityscapes.dot", 'w') as f:
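        # export_graphviz writes to the open file and returns None here,
        # so the result must not be assigned back to f.
        tree.export_graphviz(clf, out_file=f)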

    # dot_data = StringIO()
    # tree.export_graphviz(clf, out_file=dot_data)
    # graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    # graph.write_pdf("cityscapes.pdf")


    # scores = cross_val_score(clf, data, label, cv=10)
    # print("Cross validation score is " + str(scores.mean()))

    return fit_clf
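
A minimal usage sketch, assuming `all_feature_data` is a two-element sequence of feature rows and labels, as the indexing in the function suggests. The synthetic data and the helper name `cross_validated_accuracy` are hypothetical and not part of the SLIC_cityscapes project. It follows the commented-out `cross_val_score` call to estimate accuracy on held-out folds, since a tree with no depth limit usually scores close to 100% on the data it was fitted to.

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier


def cross_validated_accuracy(all_feature_data, cv=5, random_state=0):
    """Mean k-fold accuracy of a default decision tree (hypothetical helper)."""
    data = np.asarray(all_feature_data[0])
    label = np.asarray(all_feature_data[1])
    clf = DecisionTreeClassifier(random_state=random_state)
    # Each fold is scored on samples the tree did not see during fitting.
    scores = cross_val_score(clf, data, label, cv=cv)
    return scores.mean()


# Synthetic, well-separated two-class data, for illustration only.
rng = np.random.RandomState(0)
features = np.vstack([rng.normal(0.0, 1.0, size=(50, 3)),
                      rng.normal(5.0, 1.0, size=(50, 3))])
labels = np.array([0] * 50 + [1] * 50)

mean_score = cross_validated_accuracy((features, labels))
print("Cross validation score is " + str(mean_score))
```

On real features the cross-validated score is normally lower than the training accuracy printed by `decision_tree_classifier`; a large gap between the two suggests overfitting, which parameters such as `max_depth` in the commented-out constructor call can limit.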