def test_graphviz_errors():
    """Check that export_graphviz raises on empty name lists.

    A classifier limited to depth 3 is fitted on the ``X``/``y`` data
    defined outside this function.  It is then exported once with an empty
    ``feature_names`` list and once with an empty ``class_names`` list;
    each export is expected to raise ``IndexError``.
    """
    fitted_tree = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
    fitted_tree.fit(X, y)

    # The same expectation applies to both name arguments.  They are tried
    # in a fixed order, and every attempt writes to its own fresh buffer.
    for name_kwarg in ("feature_names", "class_names"):
        buffer = StringIO()
        assert_raises(
            IndexError, export_graphviz, fitted_tree, buffer, **{name_kwarg: []}
        )
评论列表
文章目录