def constructModel(corpus, classList, features, modelOutput):
    """
    Trains a bagging classifier on a 70/30 train/test split of the corpus.

    Args:
        corpus: A list of lists. In every row the last item is the class label and all
            preceding items are the feature values (e.g. GC content, coverage).
            NOTE: this list is sorted in place.
        classList: A list of class names, used to label the exported tree.
        features: List of feature names, one per feature column of the corpus.
        modelOutput: Location to save a GraphViz DOT file, or a falsy value (e.g. False)
            to save nothing. Only the first tree of the ensemble is exported.
    Returns:
        classifier: A BaggingClassifier object that has been fitted on the training split.
    """
    # Just in case: a canonical row order keeps the fixed-seed split below
    # independent of the order in which the caller assembled the corpus.
    corpus.sort()
    X = [item[:-1] for item in corpus]  # all but the last item
    Y = [item[-1] for item in corpus]  # only the last item
    X_train, X_test, Y_train, Y_test = mscv.train_test_split(X, Y, test_size=0.3, random_state=0)

    # TODO: implement classifier testing and comparison, now only baggingClassifier is used as per paper
    # Candidates to compare against: tree.DecisionTreeClassifier(),
    # ensemble.RandomForestClassifier(n_estimators=10),
    # ensemble.AdaBoostClassifier(n_estimators=100),
    # ensemble.GradientBoostingClassifier(n_estimators=100).
    baggingClassifier = ensemble.BaggingClassifier()
    baggingClassifier = baggingClassifier.fit(X_train, Y_train)
    click.echo("Bagging classifier built, score is %s out of 1.00" % baggingClassifier.score(X_test, Y_test))

    if modelOutput:
        # export_graphviz accepts a single decision tree, not an ensemble, so
        # passing the BaggingClassifier itself fails. Export the first fitted
        # member tree as a representative of the model instead.
        firstTree = baggingClassifier.estimators_[0]
        # Name the columns through the feature indices recorded for that tree,
        # so the labels stay correct even if feature subsampling is enabled.
        treeFeatures = [features[i] for i in baggingClassifier.estimators_features_[0]]
        with open(modelOutput, 'w') as dotfile:
            tree.export_graphviz(firstTree, out_file=dotfile, feature_names=treeFeatures,
                                 class_names=classList, filled=True, rounded=True,
                                 special_characters=True)
    return baggingClassifier
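# Comment list        (appears to be a leftover page-navigation label from the source web page)
# Table of contents   (appears to be a leftover page-navigation label from the source web page)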