def visualize_tree(clf, feature_names, class_names, output_file,
                   method='pdf'):
    """Render a fitted decision tree through Graphviz.

    The tree is exported to DOT text with ``tree.export_graphviz``, parsed
    into a pydotplus graph, and then either written to a PDF file or shown
    inline as a PNG.

    Args:
        clf: The fitted tree estimator handed to ``tree.export_graphviz``.
        feature_names: Feature labels used for the split nodes.
        class_names: Class labels used for the leaf nodes.
        output_file: Output path *without* extension; ``".pdf"`` is appended
            when ``method == 'pdf'``. Unused for other methods.
        method: ``'pdf'`` (default) writes ``output_file + ".pdf"``;
            ``'inline'`` displays the PNG rendering. Any other value renders
            nothing, but the graph is still returned.

    Returns:
        The graph object produced by ``pydotplus.graph_from_dot_data``.
    """
    dot_data = StringIO()
    # Use the caller-supplied labels; the previous version read the global
    # ``iris`` dataset here, which silently ignored both parameters.
    tree.export_graphviz(clf, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_names,
                         filled=True, rounded=True,
                         special_characters=True,
                         impurity=False)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    if method == 'pdf':
        graph.write_pdf(output_file + ".pdf")
    elif method == 'inline':
        # A bare ``Image(...)`` expression inside a function is discarded, so
        # nothing would be shown; hand it to ``display`` explicitly.
        # NOTE(review): assumes ``Image`` is ``IPython.display.Image`` --
        # confirm against the file's imports.
        from IPython.display import display
        display(Image(graph.create_png()))
    return graph
# An example using the iris dataset
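# (Page-navigation residue from the source web page -- "comment list" and
# "table of contents" -- not part of the program.)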