ScikitLearners.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:Aion 作者: aleisalem 项目源码 文件源码
def predictAndTestKNN(X, y, Xtest, ytest, K=10, selectKBest=0):
    """
    Trains a K-NN classifier on the training data and predicts labels for both the training and the test data
    :param X: The matrix of training feature vectors
    :type X: list
    :param y: The labels corresponding to the training feature vectors
    :type y: list
    :param Xtest: The matrix of test feature vectors
    :type Xtest: list
    :param ytest: The labels corresponding to the test feature vectors (accepted for backward compatibility; it is
                  deliberately NOT used for feature selection or fitting, to avoid leaking test labels)
    :type ytest: list
    :param K: The number of nearest neighbors to consider in classification
    :type K: int
    :param selectKBest: The number of best features to select, scored with chi2 on the training data only;
                        values <= 0 disable feature selection. NOTE(review): if chi2 is
                        sklearn.feature_selection.chi2 it requires non-negative feature values - confirm for callers
    :type selectKBest: int
    :return: The labels predicted for the training vectors and the labels predicted for the test vectors
             (as returned by clf.predict), or two empty lists if any step raises an exception
    """
    try:
        clf = neighbors.KNeighborsClassifier(n_neighbors=K)
        X, y, Xtest = numpy.array(X), numpy.array(y), numpy.array(Xtest)
        if selectKBest > 0:
            prettyPrint("Selecting %s best features from feature vectors" % selectKBest)
            # Fit the selector on the training data only and apply the SAME column selection
            # to the test data. Fitting a separate selector on (Xtest, ytest) would pick a
            # different feature subset than the model was trained on and would use test labels.
            selector = SelectKBest(chi2, k=selectKBest).fit(X, y)
            X_new = selector.transform(X)
            Xtest_new = selector.transform(Xtest)
        else:
            X_new, Xtest_new = X, Xtest
        # Fit model
        prettyPrint("Fitting model")
        clf.fit(X_new, y)
        # Predict on the training data (resubstitution) and on the test data
        prettyPrint("Validating model using training data")
        predicted = clf.predict(X_new)
        prettyPrint("Testing model")
        predicted_test = clf.predict(Xtest_new)

    except Exception as e:
        # Existing best-effort contract: report the error and return empty results instead of raising
        prettyPrintError(e)
        return [], []

    return predicted, predicted_test
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号