models.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:AutoML4 作者: djajetic 项目源码 文件源码
def __init__(self, info, verbose=True, debug_mode=False):
        """Select and configure the underlying estimator from dataset metadata.

        Args:
            info: Mapping of dataset metadata. The keys read here are
                'label_num', 'target_num', 'task', 'metric', 'is_sparse'
                and 'has_categorical'.
            verbose: Forwarded as the ``verbose`` argument of whichever
                estimator gets constructed.
            debug_mode: When >= 2, a RandomPredictor is installed and no
                real estimator is built.

        After construction the instance exposes ``name`` (label of the chosen
        estimator), ``model`` (the estimator object), ``predict_method``
        (bound ``predict`` for regression, bound ``predict_proba`` otherwise)
        and ``postprocessor``.
        """
        self.label_num = info['label_num']
        self.target_num = info['target_num']
        self.task = info['task']
        self.metric = info['metric']
        # Alternative kept for reference: MultiLabelEnsemble(LogisticRegression(), balance=True)
        self.postprocessor = MultiLabelEnsemble(LogisticRegression(), balance=False)  # To calibrate proba

        if debug_mode >= 2:
            self.name = "RandomPredictor"
            self.model = RandomPredictor(self.target_num)
            self.predict_method = self.model.predict_proba
            return

        if info['task'] == 'regression':
            # NOTE: the explicit "== True" is kept on purpose; switching to a
            # truthiness test would change the branch taken for values other
            # than 0/1/bool.
            if info['is_sparse'] == True:
                self.name = "BaggingRidgeRegressor"
                self.model = BaggingRegressor(base_estimator=Ridge(), n_estimators=1, verbose=verbose)  # unfortunately, no warm start...
            else:
                self.name = "GradientBoostingRegressor"
                self.model = GradientBoostingRegressor(
                    n_estimators=1,
                    max_depth=4,
                    min_samples_split=14,
                    verbose=verbose,
                    warm_start=True,
                )
            # Regression exposes plain predictions, not probabilities.
            self.predict_method = self.model.predict
        else:
            if info['has_categorical']:  # Out of laziness, we do not convert categorical variables...
                self.name = "RandomForestClassifier"
                self.model = RandomForestClassifier(n_estimators=1, verbose=verbose)  # unfortunately, no warm start...
            elif info['is_sparse']:
                self.name = "BaggingNBClassifier"
                self.model = BaggingClassifier(base_estimator=BernoulliNB(), n_estimators=1, verbose=verbose)  # unfortunately, no warm start...
            else:
                self.name = "GradientBoostingClassifier"
                # Construct the estimator directly instead of eval()-ing a
                # source string assembled from the class name and str(verbose).
                self.model = GradientBoostingClassifier(
                    n_estimators=1,
                    verbose=verbose,
                    random_state=1,
                    warm_start=True,
                )
            if info['task'] == 'multilabel.classification':
                # Multilabel tasks wrap the chosen classifier in the ensemble helper.
                self.model = MultiLabelEnsemble(self.model)
            self.predict_method = self.model.predict_proba
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号