insights.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:menrva 作者: amirziai 项目源码 文件源码
def decision_tree(X, y, regression, max_depth=3):
    from sklearn.tree import export_graphviz
    from sklearn.externals.six import StringIO  
    from IPython.core.pylabtools import figsize
    from IPython.display import Image
    figsize(12.5, 6)
    import pydot

    if regression:
        clf = DecisionTreeRegressor(max_depth=max_depth)
    else:
        clf = DecisionTreeClassifier(max_depth=max_depth)

    clf.fit(X, y)
    dot_data = StringIO()  
    export_graphviz(clf, out_file=dot_data, feature_names=list(X.columns),
                    filled=True, rounded=True,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())  
    return Image(graph.create_png())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号