models.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:color-features 作者: skearnes 项目源码 文件源码
def get_model():
    if FLAGS.model == 'logistic':
        return linear_model.LogisticRegressionCV(class_weight='balanced',
                                                 scoring='roc_auc',
                                                 n_jobs=FLAGS.n_jobs,
                                                 max_iter=10000, verbose=1)
    elif FLAGS.model == 'random_forest':
        return ensemble.RandomForestClassifier(n_estimators=100,
                                               n_jobs=FLAGS.n_jobs,
                                               class_weight='balanced',
                                               verbose=1)
    elif FLAGS.model == 'svm':
        return grid_search.GridSearchCV(
            estimator=svm.SVC(kernel='rbf', gamma='auto',
                              class_weight='balanced'),
            param_grid={'C': np.logspace(-4, 4, 10)}, scoring='roc_auc',
            n_jobs=FLAGS.n_jobs, verbose=1)
    else:
        raise ValueError('Unrecognized model %s' % FLAGS.model)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号