def fit(self, X, *args, **kwargs):
    """Fit the wrapped estimator on the feature columns of ``X``.

    ``X`` holds both the features and the target. The column named
    ``self.target`` is split off as ``y``. Every other column, minus any
    excluded ones, becomes the feature matrix.

    Model selection:

    * if ``self._grid_search`` is truthy, ``self._model`` is wrapped in
      ``GridSearchCV(self._model, **self._grid_search)`` and the search
      object is stored on ``self._grid``;
    * otherwise, if ``self._random_search`` is truthy, it is wrapped in
      ``RandomizedSearchCV(self._model, **self._random_search)`` and stored
      on ``self._rnd``;
    * otherwise ``self._model`` itself is fitted.

    Column selection: ``self._columns_exclude`` names columns to drop. When
    it is empty and ``self._columns_include`` is non-empty, every feature
    column not listed in ``self._columns_include`` is dropped instead. The
    resulting feature column list is stored on ``self._xcols``.

    Parameters
    ----------
    X
        Tabular input exposing ``.columns`` and column indexing (presumably
        a pandas DataFrame). It must contain the ``self.target`` column.
    *args
        Accepted for call-signature compatibility and ignored.
    **kwargs
        Forwarded unchanged to the underlying ``fit`` call.

    Returns
    -------
    self

    Raises
    ------
    AssertionError
        If ``X`` has no ``self.target`` column (only while assertions are
        enabled).
    ValueError
        If ``self._columns_exclude`` names a column that is not among the
        feature columns of ``X``. This includes the target itself, which
        has already been removed.
    """
    if self._grid_search:
        model = GridSearchCV(self._model, **self._grid_search)
        # Record the search object under the same condition that created it,
        # so _grid can never end up holding a plain or randomized model.
        self._grid = model
    elif self._random_search:
        model = RandomizedSearchCV(self._model, **self._random_search)
        self._rnd = model
    else:
        model = self._model

    assert (self.target in X.columns.values), 'X must contain the target column'
    self._xcols = list(X.columns.values)
    self._xcols.remove(self.target)

    # Derive the exclusion list locally rather than writing it back to
    # self._columns_exclude. Persisting it made every later fit() skip the
    # include-list derivation and reuse stale exclusions, which failed in
    # list.remove on frames with different columns.
    columns_exclude = self._columns_exclude
    if len(columns_exclude) == 0 and len(self._columns_include) > 0:
        include = set(self._columns_include)
        columns_exclude = [col for col in self._xcols if col not in include]
    for col in columns_exclude:
        self._xcols.remove(col)

    x = X[self._xcols]
    y = X[self.target]
    model.fit(x, y, **kwargs)
    return self
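# 评论列表 ("Comment list") -- web-page navigation residue, not code
# 文章目录 ("Table of contents") -- web-page navigation residue, not code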