def decision_tree_classifier(all_feature_data, dot_path="cityscapes.dot"):
    """Fit a decision tree on the given features/labels and export it.

    Trains a scikit-learn ``DecisionTreeClassifier`` with default
    hyper-parameters, prints its accuracy on the training data, and writes
    the fitted tree to a Graphviz DOT file.

    Args:
        all_feature_data: Indexable pair ``(features, labels)``. ``features``
            must be convertible to a 2-D array of shape
            ``(n_samples, n_features)``; ``labels`` to an array with one
            target (or one row of targets) per sample.
        dot_path: Path of the DOT file to write. Defaults to
            ``"cityscapes.dot"``; pass ``None`` to skip the export.

    Returns:
        The fitted ``DecisionTreeClassifier``.
    """
    data = np.asarray(all_feature_data[0])
    label = np.asarray(all_feature_data[1])

    clf = DecisionTreeClassifier()
    clf.fit(data, label)

    # Scored on the very samples the tree was fitted on, so this is an
    # optimistic figure, not an estimate of generalization accuracy.
    # np.mean over the element-wise match also stays in [0, 1] for
    # multi-output labels, unlike dividing a full sum by len(label).
    accuracy = float(np.mean(clf.predict(data) == label))
    # A parenthesised single-argument print behaves the same on Python 2 and 3.
    print("Training accuracy is " + str(accuracy))

    if dot_path is not None:
        with open(dot_path, "w") as dot_file:
            # export_graphviz writes into dot_file; it returns nothing useful
            # when out_file is given, so its result is not kept.
            tree.export_graphviz(clf, out_file=dot_file)
    return clf
decision_tree_classifier.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录