classification.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Oedipus 作者: tum-i22 项目源码 文件源码
def classifyTree(Xtr, ytr, Xte, yte, splitCriterion="gini", maxDepth=0, visualizeTree=False):
    """ Classifies data using CART (scikit-learn's DecisionTreeClassifier).

    Trains a decision tree on (Xtr, ytr), predicts labels for Xte and scores
    the predictions against yte. Any exception is logged via prettyPrint and
    swallowed; the function then returns whatever results were computed
    before the failure (the initial defaults if nothing succeeded).

    :param Xtr: training feature vectors
    :param ytr: training labels
    :param Xte: test feature vectors
    :param yte: ground-truth labels for Xte
    :param splitCriterion: split criterion passed to the tree as ``criterion``
    :param maxDepth: maximum tree depth; 0, negative values or None mean
                     "no depth limit"
    :param visualizeTree: if True, export the learned tree to
                          "tree_<timestamp>.pdf"
    :return: tuple (accuracyRate, timing, probabilities, predicted):
             accuracyRate -- test accuracy rounded to 2 decimals (0.0 on failure)
             timing       -- wall-clock seconds from just before training until
                             prediction finished (0.0 on failure)
             probabilities -- list holding the predict_proba output for Xte
             predicted    -- predicted labels for Xte ([] if prediction failed)
    """
    # Initialize every return value up front so the final return never hits
    # an unbound name when an exception fires early (previously `predicted`
    # was unbound whenever construction, fit or predict raised).
    accuracyRate, probabilities, timing, predicted = 0.0, [], 0.0, []
    try:
        # scikit-learn rejects max_depth <= 0; treat such values as "no limit"
        # (None) so the default maxDepth=0 no longer makes every call fail.
        depth = maxDepth if maxDepth and maxDepth > 0 else None
        cartClassifier = tree.DecisionTreeClassifier(criterion=splitCriterion, max_depth=depth)
        startTime = time.time()
        prettyPrint("Training a CART tree for classification using \"%s\" and maximum depth of %s" % (splitCriterion, maxDepth), "debug")
        cartClassifier.fit(numpy.array(Xtr), numpy.array(ytr))
        prettyPrint("Submitting the test samples", "debug")
        predicted = cartClassifier.predict(Xte)
        endTime = time.time()
        timing = endTime - startTime  # Keep track of performance
        # Compare predictions with the ground truth; accuracy_score expects
        # (y_true, y_pred) in that order.
        accuracyRate = round(metrics.accuracy_score(yte, predicted), 2)
        # Also keep the probability estimates
        probabilities.append(cartClassifier.predict_proba(Xte))
        if visualizeTree:
            # Build the file name once so the log message and the written file
            # agree. Previously the log said "tritonTree_..." while the file was
            # saved as "tree_...", each with its own getTimestamp() call.
            pdfName = "tree_%s.pdf" % getTimestamp()
            dot_data = StringIO()
            tree.export_graphviz(cartClassifier, out_file=dot_data)
            graph = pydot.graph_from_dot_data(dot_data.getvalue())
            # Newer pydot releases return a list of graphs instead of one graph.
            if isinstance(graph, list):
                graph = graph[0]
            prettyPrint("Saving learned CART to \"%s\"" % pdfName, "debug")
            graph.write_pdf(pdfName)

    except Exception as e:
        # Best-effort: log and fall through to return the partial results.
        prettyPrint("Error encountered in \"classifyTree\": %s" % e, "error")

    return accuracyRate, timing, probabilities, predicted
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号