def getEstimator(es):
    """Build an unfitted classifier selected by ``es.ml_algorithm``.

    The algorithm name is matched case-insensitively against
    ``NAIVEBAYESGAUSSIAN``, ``SVM``, ``RF``, ``DECISIONTREE`` and ``RANDOM``.
    The estimator classes are presumably scikit-learn's; the imports are not
    visible in this chunk, so verify.

    Args:
        es: settings object. The function reads:
            - ``ml_algorithm`` (str) for every algorithm;
            - ``random_seed`` for every algorithm except Gaussian naive Bayes;
            - ``svmKernel`` for SVM only.

    Returns:
        A new estimator instance. If the algorithm name is not recognised,
        the function prints a hint listing the valid names and returns None.
    """
    # The factories are lambdas so that only the selected estimator is
    # constructed. Each one reads only the ``es`` attributes it needs.
    factories = {
        'NAIVEBAYESGAUSSIAN': lambda: naive_bayes.GaussianNB(),
        'SVM': lambda: svm.SVC(kernel=es.svmKernel, degree=3, C=0.1,
                               random_state=es.random_seed),
        'RF': lambda: RandomForestClassifier(n_estimators=100,
                                             random_state=es.random_seed),
        'DECISIONTREE': lambda: DecisionTreeClassifier(random_state=es.random_seed),
        'RANDOM': lambda: DummyClassifier(random_state=es.random_seed),
    }
    factory = factories.get(es.ml_algorithm.upper())
    if factory is None:
        # Keep the print-and-return-None behaviour (rather than raising) so
        # that existing callers which test for None keep working. The hint
        # now also lists 'Random', which this function accepts.
        print("Please enter correct estimator (NaiveBayesGaussian/SVM/RF/DecisionTree/Random)")
        return None
    # TODO: add regression?
    return factory()
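# 评论列表 ("comment list") - web-page scrape residue, not part of the code
# 文章目录 ("table of contents") - web-page scrape residue, not part of the code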