random_forest.py source file

python

Project: motif    Author: rabitt
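from scipy.stats import randint as sp_randint
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.model_selection import RandomizedSearchCV
from sklearn.utils import shuffle


# fit() is a method of the file's classifier class; the class body is not shown.
def fit(self, X, Y):
    """ Train classifier.

    Hyperparameters are chosen by a cross-validated randomized search;
    the best model, refit on all of the training data, is stored in
    self.clf.

    Parameters
    ----------
    X : np.array [n_samples, n_features]
        Training features.
    Y : np.array [n_samples]
        Training labels.

    """
    # Shuffle samples and labels together before cross-validation.
    x_shuffle, y_shuffle = shuffle(X, Y, random_state=self.random_state)
    clf_cv = RFC(n_estimators=self.n_estimators, n_jobs=self.n_jobs,
                 class_weight=self.class_weight,
                 random_state=self.random_state)
    # Search space: sp_randint(a, b) samples integers in [a, b).
    # 'auto' is omitted from max_features because scikit-learn 1.3 removed
    # it; for classifiers it was equivalent to 'sqrt'.
    param_dist = {
        "max_depth": sp_randint(1, 101),
        "max_features": [None, 'sqrt', 'log2'],
        "min_samples_split": sp_randint(2, 11),
        "min_samples_leaf": sp_randint(1, 11),
        "bootstrap": [True, False],
        "criterion": ["gini", "entropy"]
    }

    # Try n_iter_search sampled settings, score each by weighted F1 over the
    # CV folds, then refit the best setting on the full shuffled training set.
    random_search = RandomizedSearchCV(
        clf_cv, param_distributions=param_dist, refit=True,
        n_iter=self.n_iter_search, scoring='f1_weighted',
        random_state=self.random_state
    )
    random_search.fit(x_shuffle, y_shuffle)
    self.clf = random_search.best_estimator_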
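To see which settings the search actually tries, the same `param_dist` dictionary can be handed to scikit-learn's `ParameterSampler`, which is what `RandomizedSearchCV` uses internally to draw its candidates. This is a minimal sketch assuming scikit-learn 1.3 or later, where `'auto'` is no longer accepted for `max_features` and is therefore left out; `n_iter=5` and `random_state=0` are arbitrary values chosen only for illustration.

```python
from scipy.stats import randint as sp_randint
from sklearn.model_selection import ParameterSampler

# Search space mirroring param_dist in fit(), without the removed 'auto' option.
param_dist = {
    "max_depth": sp_randint(1, 101),        # integers 1..100
    "max_features": [None, "sqrt", "log2"],
    "min_samples_split": sp_randint(2, 11),  # integers 2..10
    "min_samples_leaf": sp_randint(1, 11),   # integers 1..10
    "bootstrap": [True, False],
    "criterion": ["gini", "entropy"],
}

# Draw 5 candidate settings; lists are sampled uniformly, and the
# scipy distributions are sampled via their rvs() method.
candidates = list(ParameterSampler(param_dist, n_iter=5, random_state=0))
for params in candidates:
    print(params)
```

Each printed dictionary is one forest configuration that would be cross-validated; `n_iter_search` in `fit` controls how many of these are drawn.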
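The excerpt only shows `fit`, which reads `n_estimators`, `n_jobs`, `class_weight`, `n_iter_search` and `random_state` from `self`. The sketch below wraps the same procedure in a hypothetical `TunedRandomForest` class (not the actual class from the motif project) so it can be run end to end on synthetic data from `make_classification`. The constructor defaults and the `predict` helper are assumptions made for illustration only.

```python
import numpy as np
from scipy.stats import randint as sp_randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.utils import shuffle


class TunedRandomForest:
    """Hypothetical holder for the attributes that fit() reads from self."""

    def __init__(self, n_estimators=10, n_jobs=1, class_weight=None,
                 n_iter_search=4, random_state=0):
        self.n_estimators = n_estimators
        self.n_jobs = n_jobs
        self.class_weight = class_weight
        self.n_iter_search = n_iter_search
        self.random_state = random_state
        self.clf = None

    def fit(self, X, Y):
        # Same steps as the excerpt: shuffle, randomized search, keep best model.
        x_shuffle, y_shuffle = shuffle(X, Y, random_state=self.random_state)
        clf_cv = RandomForestClassifier(
            n_estimators=self.n_estimators, n_jobs=self.n_jobs,
            class_weight=self.class_weight, random_state=self.random_state)
        param_dist = {
            "max_depth": sp_randint(1, 101),
            "max_features": [None, "sqrt", "log2"],
            "min_samples_split": sp_randint(2, 11),
            "min_samples_leaf": sp_randint(1, 11),
            "bootstrap": [True, False],
            "criterion": ["gini", "entropy"],
        }
        search = RandomizedSearchCV(
            clf_cv, param_distributions=param_dist, refit=True,
            n_iter=self.n_iter_search, scoring="f1_weighted",
            random_state=self.random_state)
        search.fit(x_shuffle, y_shuffle)
        self.clf = search.best_estimator_
        return self

    def predict(self, X):
        # Hypothetical helper: delegate to the refit best estimator.
        return self.clf.predict(X)


X, y = make_classification(n_samples=200, n_features=8, random_state=0)
model = TunedRandomForest(n_estimators=20, n_iter_search=4, random_state=0)
model.fit(X, y)
pred = model.predict(X)
print(model.clf.get_params())
```

The small `n_estimators` and `n_iter_search` values only keep the demo fast; the selected hyperparameters can be read back from `model.clf.get_params()`.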