def decision_tree(X, y, regression, max_depth=3):
    """Fit a shallow decision tree and render it as a PNG image for notebook display.

    Parameters
    ----------
    X : DataFrame-like
        Feature matrix. It must expose ``.columns`` (e.g. a pandas DataFrame),
        because the column names label the split nodes in the rendered tree.
    y : array-like
        Target values passed to ``fit``.
    regression : bool
        If truthy, fit a ``DecisionTreeRegressor``; otherwise fit a
        ``DecisionTreeClassifier``.
    max_depth : int, default 3
        Maximum depth of the fitted tree.

    Returns
    -------
    IPython.display.Image
        PNG rendering of the fitted tree, produced via Graphviz/pydot.
    """
    from IPython.core.pylabtools import figsize
    from IPython.display import Image
    from sklearn.tree import (
        DecisionTreeClassifier,
        DecisionTreeRegressor,
        export_graphviz,
    )
    import pydot

    # Sets matplotlib's default figure size (a global side effect kept from the
    # original). It does not affect the Graphviz-rendered PNG returned below.
    figsize(12.5, 6)

    estimator_cls = DecisionTreeRegressor if regression else DecisionTreeClassifier
    clf = estimator_cls(max_depth=max_depth)
    clf.fit(X, y)

    # With out_file=None, export_graphviz returns the DOT source as a string.
    # This replaces the StringIO buffer from sklearn.externals.six, which was
    # removed in scikit-learn 0.23.
    dot_data = export_graphviz(
        clf,
        out_file=None,
        feature_names=list(X.columns),
        filled=True,
        rounded=True,
    )

    graphs = pydot.graph_from_dot_data(dot_data)
    # pydot >= 1.2 returns a list of graphs, while older versions return a
    # single graph object. Accept both.
    graph = graphs[0] if isinstance(graphs, list) else graphs
    return Image(graph.create_png())
评论列表
文章目录