def train_tree_classifer(features, labels, model_output_path,
                         pdf_output_path="best_tree.pdf", n_jobs=20):
    """Train a DecisionTree classifier, save it, and render it to a PDF.

    Holds out 20% of the samples for evaluation and grid-searches
    ``max_depth`` with 10-fold cross validation on the remaining 80%. It then:

    * saves the best estimator to ``model_output_path`` with ``joblib.dump``,
      provided the directory that should contain that file exists;
    * renders the best tree to ``pdf_output_path`` via graphviz/pydotplus;
    * prints the best parameters, plus a confusion matrix and a
      classification report computed on the held-out samples.

    Args:
        features: 2D array holding the input features of each sample.
        labels: array of string labels classifying each sample.
        model_output_path: file path where the trained tree model is stored.
        pdf_output_path: file path for the rendered tree PDF. Defaults to
            ``"best_tree.pdf"``, the name previously hard-coded here.
        n_jobs: number of parallel jobs used by the grid search (default 20).

    Returns:
        None. Results are reported through the written files and stdout.
    """
    # NOTE(review): assuming ``cross_validation``/``grid_search`` are the old
    # ``sklearn`` submodules. Those were removed in scikit-learn 0.20;
    # ``sklearn.model_selection`` provides ``train_test_split`` and
    # ``GridSearchCV`` on newer versions.

    # Save 20% of the data for a final performance evaluation.
    X_train, X_test, y_train, y_test = cross_validation.train_test_split(
        features, labels, test_size=0.2)

    param_grid = [{"max_depth": [None, 10, 100, 1000, 10000]}]
    dtree = tree.DecisionTreeClassifier(random_state=0)
    # 10-fold cross validation. Each fold and each parameter set can be
    # trained independently, so they are spread across ``n_jobs`` workers.
    clf = grid_search.GridSearchCV(dtree, param_grid,
                                   cv=10, n_jobs=n_jobs, verbose=3)
    clf.fit(X_train, y_train)

    # On a fresh run the model file does not exist yet. What must exist is
    # the directory that will contain it. abspath() ensures a bare file name
    # resolves to the current working directory instead of "".
    output_dir = os.path.dirname(os.path.abspath(model_output_path))
    if os.path.isdir(output_dir):
        joblib.dump(clf.best_estimator_, model_output_path)
    else:
        print("Cannot save trained tree model to {0}.".format(model_output_path))

    dot_data = tree.export_graphviz(clf.best_estimator_, out_file=None)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf(pdf_output_path)

    print("\nBest parameters set:")
    print(clf.best_params_)

    y_predict = clf.predict(X_test)
    # Fix the label order so the confusion-matrix rows/columns line up with
    # the "Labels:" header printed above it. A new name is used so that the
    # ``labels`` parameter is not shadowed.
    label_names = sorted(set(labels))
    print("\nConfusion matrix:")
    print("Labels: {0}\n".format(",".join(label_names)))
    print(confusion_matrix(y_test, y_predict, labels=label_names))
    print("\nClassification report:")
    print(classification_report(y_test, y_predict))
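# (Web-page residue from the source article: "Comment list")
# (Web-page residue from the source article: "Table of contents")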