def train(self, training_set, training_target, fea_index):
    """Fit an entropy-based decision tree, render it to PDF and persist it.

    The fitted tree is exported as Graphviz DOT, written to
    ``output/tree-vis.pdf``, and the classifier itself is dumped to
    ``output/CART.pkl``. Nothing is returned.

    Args:
        training_set: Feature matrix passed unchanged to ``clf.fit``.
        training_target: Class labels passed unchanged to ``clf.fit``.
        fea_index: Indices into ``attr_list`` used to look up the feature
            names shown in the rendered tree (presumably one per column of
            ``training_set``, in column order -- TODO confirm with caller).
    """
    # Local import: only needed here to make sure the output directory exists.
    import os

    clf = tree.DecisionTreeClassifier(
        criterion="entropy",
        min_samples_split=30,
        class_weight="balanced",
    )
    clf = clf.fit(training_set, training_target)

    # export_graphviz matches class_names to the fitted classes by position,
    # so the names must follow clf.classes_.  Sorting the *stringified*
    # labels instead (e.g. "10" before "2") would mislabel the tree nodes.
    class_names = [str(label) for label in clf.classes_]

    # NOTE(review): attr_list is not defined in this block; it is expected
    # to be provided at module level.
    feature_names = [attr_list[i] for i in fea_index]

    dot_data = tree.export_graphviz(
        clf,
        out_file=None,  # return the DOT source as a string instead of writing a file
        feature_names=feature_names,
        class_names=class_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)

    # Create the target directory up front so a missing folder does not
    # discard the freshly trained model.
    os.makedirs("output", exist_ok=True)
    graph.write_pdf("output/tree-vis.pdf")
    joblib.dump(clf, "output/CART.pkl")
# (Web-page navigation residue, not part of the code: "Comment list" / "Article table of contents".)