ScikitLearners.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Aion 作者: aleisalem 项目源码 文件源码
def predictAndTestRandomForest(X, y, Xtest, ytest, estimators=10, criterion="gini", maxdepth=None, selectKBest=0):
    """
    Trains a random forest on the training data, then predicts labels for both the training and the test data.

    No cross validation is performed: the model is fit once on (X, y).

    :param X: The matrix of training feature vectors
    :type X: list
    :param y: The labels corresponding to the training feature vectors
    :type y: list
    :param Xtest: The matrix of test feature vectors
    :type Xtest: list
    :param ytest: The labels corresponding to the test feature vectors. Kept for interface compatibility;
                  it is deliberately NOT used for fitting or feature selection (that would leak test labels)
    :type ytest: list
    :param estimators: The number of random trees to use in classification
    :type estimators: int
    :param criterion: The splitting criterion employed by the decision trees
    :type criterion: str
    :param maxdepth: The maximum depth each tree is allowed to grow (None means unlimited)
    :type maxdepth: int
    :param selectKBest: The number of best features (by chi2 score) to select; 0 disables feature selection
    :type selectKBest: int
    :return: A tuple (predictions on X, predictions on Xtest), or ([], []) if any step raised an exception
             (the exception is reported via prettyPrintError)
    """
    try:
        clf = ensemble.RandomForestClassifier(n_estimators=estimators, criterion=criterion, max_depth=maxdepth)
        X, y, Xtest = numpy.array(X), numpy.array(y), numpy.array(Xtest)

        if selectKBest > 0:
            prettyPrint("Selecting %s best features from feature vectors" % selectKBest)
            # Fit the selector on the training data only and reuse it for the test data.
            # A separately fitted selector on (Xtest, ytest) would use the test labels and
            # could pick different columns than those the classifier was trained on.
            selector = SelectKBest(chi2, k=selectKBest)
            X_new = selector.fit_transform(X, y)
            Xtest_new = selector.transform(Xtest)
        else:
            X_new, Xtest_new = X, Xtest

        prettyPrint("Fitting model")
        clf.fit(X_new, y)

        prettyPrint("Validating model using training data")
        predicted = clf.predict(X_new)
        prettyPrint("Testing model")
        predicted_test = clf.predict(Xtest_new)

    except Exception as e:
        # Best-effort contract preserved from the original: report and return empty results.
        prettyPrintError(e)
        return [], []

    return predicted, predicted_test
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号