# Imports needed by this snippet (they may already appear earlier in the full module)
import time
from io import StringIO

import numpy
import pydot
from sklearn import metrics, tree
from sklearn.model_selection import KFold  # KFold lives here since scikit-learn 0.18

# prettyPrint is assumed to be the project's logging helper, defined elsewhere in the module.
# maxDepth defaults to None (grow the tree fully) because scikit-learn rejects max_depth=0.
def classifyTreeKFold(X, y, kFold=2, splitCriterion="gini", maxDepth=None, visualizeTree=False):
    """ Classifies data using CART and K-Fold cross validation """
    try:
        groundTruthLabels, predictedLabels = [], []
        accuracyRates = []   # Meant to hold the per-fold accuracy rates
        probabilities = []   # Per-fold probability estimates
        timings = []         # Per-fold training/prediction times
        # Split data into training and test datasets
        trainingDataset, testDataset = [], []
        trainingLabels, testLabels = [], []
        kFoldValidator = KFold(n_splits=kFold, shuffle=False)
        currentFold = 1
        for trainingIndices, testIndices in kFoldValidator.split(X):
            # Prepare the training and testing datasets
            for trIndex in trainingIndices:
                trainingDataset.append(X[trIndex])
                trainingLabels.append(y[trIndex])
            for teIndex in testIndices:
                testDataset.append(X[teIndex])
                testLabels.append(y[teIndex])
            # Perform classification
            startTime = time.time()
            cartClassifier = tree.DecisionTreeClassifier(criterion=splitCriterion, max_depth=maxDepth)
            prettyPrint("Training a CART tree for classification using \"%s\" and a maximum depth of %s" % (splitCriterion, maxDepth), "debug")
            cartClassifier.fit(numpy.array(trainingDataset), numpy.array(trainingLabels))
            prettyPrint("Submitting the test samples", "debug")
            predicted = cartClassifier.predict(testDataset)
            endTime = time.time()
            # Add this fold's results to the groundTruthLabels and predictedLabels lists
            groundTruthLabels.append(testLabels)
            predictedLabels.append(predicted)
            # Compare the predictions against the ground truth and append the accuracy to the list
            accuracyRates.append(round(metrics.accuracy_score(testLabels, predicted), 2))
            # Also append the probability estimates
            probs = cartClassifier.predict_proba(testDataset)
            probabilities.append(probs)
            timings.append(endTime - startTime)  # Keep track of performance
            if visualizeTree:
                # Visualize the tree learned in this fold
                dot_data = StringIO()
                tree.export_graphviz(cartClassifier, out_file=dot_data)
                # graph_from_dot_data returns a list of graphs in pydot >= 1.2
                graph = pydot.graph_from_dot_data(dot_data.getvalue())[0]
                prettyPrint("Saving learned CART to \"tritonTree_%s.pdf\"" % currentFold, "debug")
                graph.write_pdf("tritonTree_%s.pdf" % currentFold)
            # Reset the datasets before the next fold
            trainingDataset, trainingLabels = [], []
            testDataset, testLabels = [], []
            currentFold += 1
    except Exception as e:
        prettyPrint("Error encountered in \"classifyTreeKFold\": %s" % e, "error")
        return [], [], [], [], []

    return accuracyRates, probabilities, timings, groundTruthLabels, predictedLabels
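
For reference, a minimal usage sketch is shown below. The feature vectors, labels, and parameter values are made up for illustration, and prettyPrint is assumed to be available from the surrounding module:

# Hypothetical toy dataset: six 2-dimensional samples from two classes
X = [[0.1, 1.0], [0.9, 0.2], [0.2, 0.8], [0.8, 0.1], [0.15, 0.95], [0.85, 0.15]]
y = [0, 1, 0, 1, 0, 1]

accuracies, probabilities, timings, truth, predictions = classifyTreeKFold(
    X, y, kFold=2, splitCriterion="entropy", maxDepth=3, visualizeTree=False)
print("Per-fold accuracies: %s" % accuracies)
print("Per-fold training/prediction times: %s" % timings)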