def classifyTree(Xtr, ytr, Xte, yte, splitCriterion="gini", maxDepth=0, visualizeTree=False):
    """Train a CART decision tree on the training set and evaluate it on the test set.

    Args:
        Xtr: Training feature vectors (converted with ``numpy.array`` before fitting).
        ytr: Training labels (converted with ``numpy.array`` before fitting).
        Xte: Test feature vectors passed to ``predict`` / ``predict_proba``.
        yte: Ground-truth labels of the test samples.
        splitCriterion: Split quality criterion forwarded to ``DecisionTreeClassifier``.
        maxDepth: Maximum tree depth. A non-positive value (including the default
            ``0``) or ``None`` means "no depth limit".
        visualizeTree: When true, export the learned tree to ``tree_<timestamp>.pdf``.

    Returns:
        A tuple ``(accuracyRate, timing, probabilities, predicted)``:
        accuracy rounded to two decimals, seconds spent fitting and predicting,
        a list holding the ``predict_proba`` output, and the predicted labels.
        If an error occurs, it is logged and whatever was computed so far is
        returned, with the remaining values left at ``0.0`` / ``[]``.
    """
    # Initialise every returned name before the try block so that the final
    # return cannot raise UnboundLocalError after an exception has been caught.
    accuracyRate, probabilities, timing, predicted = 0.0, [], 0.0, []
    try:
        # scikit-learn only accepts None or a positive integer for max_depth,
        # so the "0" default is translated into "unlimited depth".
        effectiveDepth = maxDepth if (maxDepth is not None and maxDepth > 0) else None
        cartClassifier = tree.DecisionTreeClassifier(criterion=splitCriterion, max_depth=effectiveDepth)
        startTime = time.time()
        prettyPrint("Training a CART tree for classification using \"%s\" and maximum depth of %s" % (splitCriterion, maxDepth), "debug")
        cartClassifier.fit(numpy.array(Xtr), numpy.array(ytr))
        prettyPrint("Submitting the test samples", "debug")
        predicted = cartClassifier.predict(Xte)
        # Record the fit+predict duration right away so it is kept even if a
        # later step (scoring, visualisation) fails.
        timing = time.time() - startTime
        # Compare the ground truth with the predictions
        accuracyRate = round(metrics.accuracy_score(yte, predicted), 2)
        # Also keep the class probability estimates
        probabilities.append(cartClassifier.predict_proba(Xte))
        if visualizeTree:
            # Render the learned tree to a PDF through graphviz/pydot
            dot_data = StringIO()
            tree.export_graphviz(cartClassifier, out_file=dot_data)
            graph = pydot.graph_from_dot_data(dot_data.getvalue())
            # Newer pydot releases return a list of graphs instead of one graph
            if isinstance(graph, (list, tuple)):
                graph = graph[0]
            # Build the file name once so the log message matches the file
            # that is actually written.
            pdfName = "tree_%s.pdf" % getTimestamp()
            prettyPrint("Saving learned CART to \"%s\"" % pdfName, "debug")
            graph.write_pdf(pdfName)
    except Exception as e:
        # Deliberate best-effort: report the problem and return partial results
        prettyPrint("Error encountered in \"classifyTree\": %s" % e, "error")
    return accuracyRate, timing, probabilities, predicted
评论列表
文章目录