classification.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:Oedipus 作者: tum-i22 项目源码 文件源码
def classifyTreeKFold(X, y, kFold=2, splitCriterion="gini", maxDepth=0, visualizeTree=False):
    """Classify data using a CART decision tree and K-Fold cross validation.

    Args:
        X: Indexable sequence of feature vectors.
        y: Indexable sequence of labels, aligned index-by-index with X.
        kFold: Number of folds handed to KFold.
        splitCriterion: Split criterion handed to DecisionTreeClassifier.
        maxDepth: Maximum depth of the tree. None or any non-positive value
            (including the default 0) means "no depth limit".
        visualizeTree: If true, each fold's tree is written to
            "tritonTree_<fold>.pdf" in the current directory.

    Returns:
        A 5-tuple (accuracyRates, probabilities, timings, groundTruthLabels,
        predictedLabels); every list holds one entry per fold. If anything
        goes wrong the error is logged through prettyPrint and five empty
        lists are returned, so callers can always unpack five values.
    """
    # scikit-learn only accepts None or a positive integer for max_depth, so
    # the default of 0 used to make every default call fail inside the "try".
    if maxDepth is not None and maxDepth > 0:
        treeDepth = maxDepth
    else:
        treeDepth = None

    accuracyRates, probabilities, timings = [], [], []
    groundTruthLabels, predictedLabels = [], []
    try:
        kFoldValidator = KFold(n=len(X), n_folds=kFold, shuffle=False)
        for currentFold, (trainingIndices, testIndices) in enumerate(kFoldValidator, 1):
            # Build fresh per-fold datasets; nothing leaks between folds
            trainingDataset = [X[index] for index in trainingIndices]
            trainingLabels = [y[index] for index in trainingIndices]
            testDataset = [X[index] for index in testIndices]
            testLabels = [y[index] for index in testIndices]
            # Perform classification (training + prediction are timed together)
            startTime = time.time()
            cartClassifier = tree.DecisionTreeClassifier(criterion=splitCriterion, max_depth=treeDepth)
            prettyPrint("Training a CART tree for classification using \"%s\" and maximum depth of %s" % (splitCriterion, treeDepth), "debug")
            cartClassifier.fit(numpy.array(trainingDataset), numpy.array(trainingLabels))
            prettyPrint("Submitting the test samples", "debug")
            predicted = cartClassifier.predict(testDataset)
            endTime = time.time()
            # Record this fold's ground truth and predictions
            groundTruthLabels.append(testLabels)
            predictedLabels.append(predicted)
            # accuracy_score expects (y_true, y_pred)
            accuracyRates.append(round(metrics.accuracy_score(testLabels, predicted), 2))
            # Also keep the class probability estimates
            probabilities.append(cartClassifier.predict_proba(testDataset))
            timings.append(endTime - startTime)  # Keep track of performance
            if visualizeTree:
                # Render the learned tree via graphviz
                dot_data = StringIO()
                tree.export_graphviz(cartClassifier, out_file=dot_data)
                graph = pydot.graph_from_dot_data(dot_data.getvalue())
                prettyPrint("Saving learned CART to \"tritonTree_%s.pdf\"" % currentFold, "debug")
                graph.write_pdf("tritonTree_%s.pdf" % currentFold)

    except Exception as e:
        # Best-effort API: log and return empty results, but keep the same
        # arity as the success path so tuple-unpacking callers do not break.
        prettyPrint("Error encountered in \"classifyTreeKFold\": %s" % e, "error")
        return [], [], [], [], []

    return accuracyRates, probabilities, timings, groundTruthLabels, predictedLabels
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号