utils_feature_selection.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:auto_ml 作者: doordash 项目源码 文件源码
def fit(self, X, y=None):
        """Fit the feature selector and record which columns to keep.

        Resolves the selector via ``get_feature_selection_model_from_name``
        and then takes one of three paths:

        * ``'KeepAll'``: every column is kept.
        * ``'SelectFromModel'``: a small random forest is fit on ``X``/``y``
          and columns are kept when their importance exceeds the stricter of
          two thresholds (relative importance, and a cap on column count
          derived from the number of rows).
        * anything else: the resolved selector is fit and its
          ``get_support()`` mask is used.

        Sets ``self.selector``, ``self.support_mask`` (one boolean per
        column) and ``self.index_mask`` (indices of the kept columns). The
        ``'SelectFromModel'`` path also sets ``self.estimator``.

        Args:
            X: 2-D feature matrix. The ``'SelectFromModel'`` path requires a
                ``.shape`` attribute; the ``'KeepAll'`` path also accepts a
                plain sequence of rows.
            y: Target values, passed through to the underlying estimator or
                selector.

        Returns:
            self
        """
        self.selector = get_feature_selection_model_from_name(self.type_of_estimator, self.feature_selection_model)

        if self.selector == 'KeepAll':
            # The mask needs one entry per COLUMN, i.e. axis 1 of X. Reading
            # shape[1] works for sparse matrices and dense arrays alike (and
            # for empty arrays, where X[0] would raise). Plain sequences of
            # rows have no shape, so fall back to the first row's length.
            shape = getattr(X, 'shape', None)
            if shape is not None and len(shape) == 2:
                num_cols = shape[1]
            else:
                num_cols = len(X[0])

            self.support_mask = [True] * num_cols

        elif self.feature_selection_model == 'SelectFromModel':
            num_rows = X.shape[0]
            if self.type_of_estimator == 'regressor':
                self.estimator = RandomForestRegressor(n_jobs=-1, max_depth=10, n_estimators=15)
            else:
                self.estimator = RandomForestClassifier(n_jobs=-1, max_depth=10, n_estimators=15)

            self.estimator.fit(X, y)
            feature_importances = self.estimator.feature_importances_

            # Two ways of doing feature selection; the stricter one wins.

            # 1. Any feature with an importance of more than 1/100th of the
            #    most important feature's importance.
            threshold_by_relative_importance = 0.01 * max(feature_importances)

            # 2. At most 1/4 as many columns as rows (100 rows -> 25 columns):
            #    the importance at position max_cols in descending order is
            #    the cut-off. If there are fewer features than that, use the
            #    smallest importance.
            # NOTE: with fewer than 4 rows max_cols is 0, the cut-off equals
            # the largest importance, and the strict ">" below keeps nothing.
            sorted_importances = sorted(feature_importances, reverse=True)
            max_cols = int(num_rows * 0.25)
            cutoff_idx = min(max_cols, len(sorted_importances) - 1)
            threshold_by_max_cols = sorted_importances[cutoff_idx]

            threshold = max(threshold_by_relative_importance, threshold_by_max_cols)
            self.support_mask = [bool(importance > threshold) for importance in feature_importances]

        else:
            self.selector.fit(X, y)
            self.support_mask = self.selector.get_support()

        # Indices of the columns we want to keep.
        self.index_mask = [idx for idx, keep in enumerate(self.support_mask) if keep]
        return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号