def get_model():
    """Build the unfitted classifier selected by ``FLAGS.model``.

    Supported values of ``FLAGS.model``:
      * ``'logistic'`` -- ``linear_model.LogisticRegressionCV`` with balanced
        class weights, ROC-AUC scoring, ``max_iter=10000`` and
        ``n_jobs=FLAGS.n_jobs``.
      * ``'random_forest'`` -- ``ensemble.RandomForestClassifier`` with 100
        trees, balanced class weights and ``n_jobs=FLAGS.n_jobs``.
      * ``'svm'`` -- an RBF ``svm.SVC`` (``gamma='auto'``, balanced class
        weights) wrapped in ``grid_search.GridSearchCV``, which searches
        ``C`` over ``np.logspace(-4, 4, 10)`` scored by ROC-AUC, with
        ``n_jobs=FLAGS.n_jobs`` on the grid search.

    Each returned estimator is constructed with ``verbose=1``.

    Returns:
      The constructed (not yet fitted) estimator.

    Raises:
      ValueError: if ``FLAGS.model`` matches none of the names above.
    """
    # NOTE(review): ``grid_search`` was removed in scikit-learn 0.20 in favour
    # of ``model_selection.GridSearchCV``; confirm which scikit-learn version
    # this file targets before upgrading.

    def _logistic():
        return linear_model.LogisticRegressionCV(
            class_weight='balanced',
            scoring='roc_auc',
            n_jobs=FLAGS.n_jobs,
            max_iter=10000,
            verbose=1,
        )

    def _random_forest():
        return ensemble.RandomForestClassifier(
            n_estimators=100,
            n_jobs=FLAGS.n_jobs,
            class_weight='balanced',
            verbose=1,
        )

    def _svm():
        base = svm.SVC(kernel='rbf', gamma='auto', class_weight='balanced')
        return grid_search.GridSearchCV(
            estimator=base,
            param_grid={'C': np.logspace(-4, 4, 10)},
            scoring='roc_auc',
            n_jobs=FLAGS.n_jobs,
            verbose=1,
        )

    # Factories are called lazily, so only the requested estimator is built;
    # names are compared in the same order as the original if/elif chain.
    factories = (
        ('logistic', _logistic),
        ('random_forest', _random_forest),
        ('svm', _svm),
    )
    requested = FLAGS.model
    for name, factory in factories:
        if requested == name:
            return factory()
    raise ValueError('Unrecognized model %s' % requested)
# Comment list
# Table of contents