CART_Trainer.py source file

python
Project: kdd99-scikit    Author: PENGZhaoqing
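import joblib  # assumption: standalone joblib package (older scikit-learn shipped it as sklearn.externals.joblib)
import pydotplus
from sklearn import tree

# Method of the trainer class. attr_list (the full list of feature names) is
# not defined in this excerpt and is assumed to exist at module level.
def train(self, training_set, training_target, fea_index):
        # Entropy-based decision tree; "balanced" reweights classes by inverse frequency.
        clf = tree.DecisionTreeClassifier(criterion="entropy", min_samples_split=30, class_weight="balanced")
        clf = clf.fit(training_set, training_target)

        # export_graphviz expects class names in the order of clf.classes_.
        # Sorting the stringified labels (the original np.unique call) can
        # reorder numeric labels, e.g. "10" sorts before "2".
        class_names = [str(c) for c in clf.classes_]
        # Keep only the names of the features selected by fea_index.
        feature_names = [attr_list[i] for i in fea_index]

        # Render the fitted tree as Graphviz DOT source (returned as a string).
        dot_data = tree.export_graphviz(clf, out_file=None,
                                        feature_names=feature_names,
                                        class_names=class_names,
                                        filled=True, rounded=True,
                                        special_characters=True)

        # Write the visualization and the fitted model; the output/ directory must already exist.
        graph = pydotplus.graph_from_dot_data(dot_data)
        graph.write_pdf("output/tree-vis.pdf")
        joblib.dump(clf, 'output/CART.pkl')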
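
The train method above ends by saving the fitted tree with joblib.dump. As a rough usage sketch (not part of the kdd99-scikit project), the saved model could be reloaded and used for prediction along the following lines. The synthetic data, the temporary file path, and the random_state argument are assumptions made only so the example runs on its own; the real project would pass its KDD99 feature matrix and load output/CART.pkl.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn import tree

# Tiny synthetic data set (hypothetical; stands in for the KDD99 features).
rng = np.random.RandomState(0)
X = rng.rand(200, 3)
y = (X[:, 0] > 0.5).astype(int)

# Same hyperparameters as the train method; random_state added for repeatability.
clf = tree.DecisionTreeClassifier(criterion="entropy", min_samples_split=30,
                                  class_weight="balanced", random_state=0)
clf.fit(X, y)

# Persist and reload, mirroring joblib.dump(clf, 'output/CART.pkl').
model_path = os.path.join(tempfile.mkdtemp(), "CART.pkl")
joblib.dump(clf, model_path)
restored = joblib.load(model_path)

# The restored tree should reproduce the original predictions exactly.
predictions = restored.predict(X)
same = bool(np.array_equal(predictions, clf.predict(X)))
print("predictions match:", same)
```

Reloading with joblib.load returns an ordinary DecisionTreeClassifier, so predict, predict_proba, and classes_ are available without retraining. Pickled models are generally only safe to load with a scikit-learn version compatible with the one that saved them.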