tree.py source file

python

Project: extra-trees  Author: allrod5
def fit(self, X, y, **kwargs):
        # Determine output settings
        n_samples, self.n_features_ = X.shape
        if self.max_features is None:
            self.max_features = 'auto'

        y = np.atleast_1d(y)

        if y.ndim == 1:
            # reshape preserves data contiguity, whereas indexing with
            # [:, np.newaxis] does not.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]
        self.classes_ = [None] * self.n_outputs_
        self.n_classes_ = [1] * self.n_outputs_
        self.n_classes_ = np.array(self.n_classes_, dtype=np.intp)

        if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
            y = np.ascontiguousarray(y, dtype=DOUBLE)

        if len(y) != n_samples:
            raise ValueError(
                "Number of labels=%d does not match number of samples=%d"
                % (len(y), n_samples))

        # Build tree
        self.tree_ = ExtraTree(
            self.max_features, self.min_samples_split, self.n_classes_,
            self.n_outputs_, self.classification)
        self.tree_.build(X, y)

        if self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        return self
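
The target-handling steps in fit (promote y to at least 1-D, reshape to a column, convert to a contiguous double array, check the length against the number of samples) can be tried in isolation. The sketch below is only an illustration under stated assumptions: the helper name prepare_targets is hypothetical and not part of extra-trees, and DOUBLE is assumed to be an alias for np.float64, since its definition is not shown in the snippet above.

```python
import numpy as np

# Assumption: DOUBLE is an alias for float64; the snippet above does not show
# where it is defined.
DOUBLE = np.float64


def prepare_targets(y, n_samples):
    """Hypothetical helper mirroring the y-validation steps of fit().

    Returns y as a C-contiguous float64 array of shape (n_samples, n_outputs).
    """
    y = np.atleast_1d(y)

    if y.ndim == 1:
        # Turn a flat vector of targets into a single-output column.
        y = np.reshape(y, (-1, 1))

    if y.dtype != DOUBLE or not y.flags.c_contiguous:
        # Copy only when the dtype or memory layout is not already suitable.
        y = np.ascontiguousarray(y, dtype=DOUBLE)

    if len(y) != n_samples:
        raise ValueError(
            "Number of labels=%d does not match number of samples=%d"
            % (len(y), n_samples))

    return y


X = np.zeros((4, 3))
y = prepare_targets([0, 1, 1, 0], X.shape[0])
print(y.shape, y.dtype, y.flags.c_contiguous)
```

Passing a target list whose length differs from X.shape[0] raises the same ValueError as fit does.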