def fit(self, X, y, **kwargs):
    """Build an ``ExtraTree`` on ``(X, y)`` and store it in ``self.tree_``.

    Parameters
    ----------
    X : array-like exposing ``.shape``
        Training inputs. ``X.shape`` is unpacked as
        ``(n_samples, n_features)``, so ``X`` must be 2-D. It is passed
        unchanged to ``ExtraTree.build``.
    y : array-like
        Targets of shape ``(n_samples,)`` or ``(n_samples, n_outputs)``.
        A 1-D ``y`` is treated as a single output. ``y`` is converted to a
        C-contiguous ``DOUBLE`` array before the tree is built.
    **kwargs
        Accepted for call-signature compatibility; ignored.

    Returns
    -------
    self
        The fitted estimator. It sets ``n_features_``, ``n_outputs_``,
        ``classes_``, ``n_classes_`` and ``tree_``. When there is a single
        output, ``classes_`` and ``n_classes_`` are scalars rather than
        length-1 containers.

    Raises
    ------
    ValueError
        If ``y`` has more than two dimensions, or if the number of rows of
        ``y`` differs from the number of rows of ``X``.
    """
    n_samples, self.n_features_ = X.shape

    # Resolve the default locally instead of overwriting the user-supplied
    # hyper-parameter: fit() should leave constructor parameters untouched.
    max_features = 'auto' if self.max_features is None else self.max_features

    y = np.atleast_1d(y)
    if y.ndim == 1:
        # np.reshape (rather than y[:, np.newaxis]) is used to keep the
        # data contiguous when turning a 1-D target into one output column.
        y = np.reshape(y, (-1, 1))
    elif y.ndim > 2:
        raise ValueError(
            "y must be 1-D or 2-D, got an array with %d dimensions" % y.ndim)

    # Validate before the (possibly copying) dtype/contiguity conversion.
    if len(y) != n_samples:
        raise ValueError(
            "Number of labels=%d does not match number of samples=%d"
            % (len(y), n_samples))

    self.n_outputs_ = y.shape[1]
    # Per-output class metadata: no labels and a class count of 1.
    self.classes_ = [None] * self.n_outputs_
    self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp)

    # y is always an ndarray here, so .dtype and .flags are available.
    if y.dtype != DOUBLE or not y.flags.contiguous:
        y = np.ascontiguousarray(y, dtype=DOUBLE)

    self.tree_ = ExtraTree(
        max_features, self.min_samples_split, self.n_classes_,
        self.n_outputs_, self.classification)
    self.tree_.build(X, y)

    if self.n_outputs_ == 1:
        # Single-output fits expose scalars rather than length-1 containers.
        self.n_classes_ = self.n_classes_[0]
        self.classes_ = self.classes_[0]
    return self
# (Web-page residue, not code: "Comment list")
# (Web-page residue, not code: "Table of contents")