models.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:syracuse_public 作者: dssg 项目源码 文件源码
def define_model(modelname):
    """
    Build an unfitted classifier object for a short model code.

    Input
    ----
    modelname: str
       Short code selecting the model type. Recognized codes:
       'LR'  - logistic regression
       'NN'  - k-nearest neighbors (not a neural network)
       'DT'  - decision tree
       'RF'  - random forest
       'NB'  - Gaussian naive Bayes
       'SVM' - support vector classifier
       'ET'  - extra trees
       'SGD' - linear classifier trained with SGD
       'AB'  - AdaBoost over depth-1 decision trees
       'GB'  - gradient boosting
       'VC'  - soft-voting ensemble of RF, ET and AdaBoost
       'VC2' - soft-voting ensemble of LR, RF and ET

    Output
    ------
    clf: model object
       A newly constructed classifier. Except for the hyperparameters
       spelled out below, every model uses the library defaults.

    Raises
    ------
    ConfigError
       If modelname is not one of the recognized codes.

    """
    if modelname == 'LR':
        return linear_model.LogisticRegression()
    elif modelname == 'NN':
        return neighbors.KNeighborsClassifier()
    elif modelname == 'DT':
        return tree.DecisionTreeClassifier()
    elif modelname == 'RF':
        return ensemble.RandomForestClassifier()
    elif modelname == 'NB':
        return naive_bayes.GaussianNB()
    elif modelname == 'SVM':
        return svm.SVC()
    elif modelname == 'ET':
        return ensemble.ExtraTreesClassifier()
    elif modelname == 'SGD':
        return linear_model.SGDClassifier()
    elif modelname == 'AB':
        # Boost decision stumps (depth-1 trees).
        return ensemble.AdaBoostClassifier(
            tree.DecisionTreeClassifier(max_depth=1)
        )
    elif modelname == 'GB':
        return ensemble.GradientBoostingClassifier()
    elif modelname == 'VC':
        # min_samples_split is 2 rather than 1: a one-sample node can never
        # be split, and newer scikit-learn releases reject an integer
        # value below 2 when the model is fitted.
        return ensemble.VotingClassifier(
            estimators=[
                ('RFC', ensemble.RandomForestClassifier(
                    n_estimators=10, max_depth=None,
                    min_samples_split=2, random_state=0)),
                ('ETC', ensemble.ExtraTreesClassifier(
                    max_depth=None, max_features=5, n_estimators=10,
                    random_state=0, min_samples_split=2)),
                ('ABC', ensemble.AdaBoostClassifier()),
            ],
            voting='soft')
    elif modelname == 'VC2':
        # Same min_samples_split=2 reasoning as for 'VC'.
        return ensemble.VotingClassifier(
            estimators=[
                ('LR', linear_model.LogisticRegression(
                    C=0.1, random_state=1)),
                ('RFC', ensemble.RandomForestClassifier(
                    max_depth=None, n_estimators=10,
                    random_state=0, min_samples_split=2)),
                ('ETC', ensemble.ExtraTreesClassifier(
                    max_depth=None, max_features=5, n_estimators=10,
                    random_state=0, min_samples_split=2)),
            ],
            voting='soft')

    else:
        # Report the code that was actually requested; the parameter is
        # named `modelname`.
        raise ConfigError("Can't find the model: {}".format(modelname))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号