def predictAndTestRandomForest(X, y, Xtest, ytest, estimators=10, criterion="gini", maxdepth=None, selectKBest=0):
    """
    Train a random forest on the training data and predict labels for both the
    training data and the test data.

    No cross-validation is performed: the "validation" predictions are simply the
    fitted model's predictions on its own training data.

    :param X: The matrix of training feature vectors
    :type X: list
    :param y: The labels corresponding to the training feature vectors
    :type y: list
    :param Xtest: The matrix of test feature vectors
    :type Xtest: list
    :param ytest: The labels corresponding to the test feature vectors. They are not
        used for fitting or feature selection; the parameter is kept for interface
        compatibility.
    :type ytest: list
    :param estimators: The number of trees in the forest
    :type estimators: int
    :param criterion: The splitting criterion employed by each decision tree
    :type criterion: str
    :param maxdepth: The maximum depth each tree is allowed to grow
    :type maxdepth: int
    :param selectKBest: The number of best features to select (chi2 scoring);
        0 or less disables feature selection
    :type selectKBest: int
    :return: A pair (predicted, predicted_test) with the predicted labels for the
        training data and the test data, or ([], []) if any step raised
    :rtype: tuple
    """
    try:
        # Define the classifier
        clf = ensemble.RandomForestClassifier(n_estimators=estimators, criterion=criterion, max_depth=maxdepth)
        X, y, Xtest = numpy.array(X), numpy.array(y), numpy.array(Xtest)
        X_new, Xtest_new = X, Xtest
        # Select K best features if enabled
        if selectKBest > 0:
            prettyPrint("Selecting %s best features from feature vectors" % selectKBest)
            # Fit the selector on the TRAINING data only, then apply the same column
            # subset to the test data. Fitting a separate selector on the test set
            # could pick different columns than the model was trained on, and it
            # would leak test labels into preprocessing.
            selector = SelectKBest(chi2, k=selectKBest).fit(X, y)
            X_new = selector.transform(X)
            Xtest_new = selector.transform(Xtest)
        # Fit model
        prettyPrint("Fitting model")
        clf.fit(X_new, y)
        # Predict on the training data ("validation") and on the test data
        prettyPrint("Validating model using training data")
        predicted = clf.predict(X_new)
        prettyPrint("Testing model")
        predicted_test = clf.predict(Xtest_new)
    except Exception as e:
        # Best-effort: report the error and return empty results instead of raising.
        prettyPrintError(e)
        return [], []
    return predicted, predicted_test
评论列表
文章目录