def fit_cv(self, data, labels, cv_params, epochs=10, **kwargs):
    """Grid-search hyperparameters, then fit a fresh SGDBolasso with the winners.

    The search runs ``GridSearchCV`` over ``cv_params`` on an SGDBolasso
    built from this estimator's parameters with ``bootstrap_fraction``
    forced to 1.0, scoring candidates by accuracy. The returned estimator
    starts from this estimator's parameters, overlays the best grid
    parameters, restores this estimator's ``bootstrap_fraction`` and is
    trained for ``epochs`` epochs on ``data``/``labels``.

    Parameters
    ----------
    data, labels
        Training inputs and targets; passed unchanged to
        ``GridSearchCV.fit`` and ``SGDBolasso.fit``.
    cv_params : dict or list of dict
        The ``param_grid`` handed to ``GridSearchCV``.
    epochs : int, default 10
        Epochs used when fitting the final returned estimator.
    **kwargs
        Optional search settings (defaults in parentheses): ``n_jobs`` (1),
        ``iid`` (True; only forwarded when the installed ``GridSearchCV``
        still accepts it), ``refit`` (True), ``cv`` (None), ``verbose`` (0),
        ``pre_dispatch`` ('2*n_jobs'), ``error_score`` ('raise'),
        ``return_train_score`` (True) and ``cv_epochs`` (1, the ``epochs``
        fit parameter used for every candidate during the search).
        Any other keys are ignored.

    Returns
    -------
    best_estim : SGDBolasso
        Newly constructed estimator fitted with the best parameters.
    search : GridSearchCV
        The fitted search object (``cv_results_``, ``best_params_``, ...).
    """
    import inspect

    n_jobs = kwargs.get('n_jobs', 1)
    iid = kwargs.get('iid', True)
    refit = kwargs.get('refit', True)
    cv = kwargs.get('cv', None)
    verbose = kwargs.get('verbose', 0)
    pre_dispatch = kwargs.get('pre_dispatch', '2*n_jobs')
    error_score = kwargs.get('error_score', 'raise')
    return_train_score = kwargs.get('return_train_score', True)
    cv_epochs = kwargs.get('cv_epochs', 1)

    # Candidates are fit briefly and quietly; the final model gets `epochs`.
    cv_fit_params = {'epochs': cv_epochs, 'verbose': 0}

    # Search-time estimator: this estimator's configuration with
    # bootstrap_fraction pinned to 1.0 (presumably "use the whole fold").
    cv_param_dct = self.get_params()
    cv_param_dct.update({'bootstrap_fraction': 1.0})

    search_kwargs = dict(
        scoring=make_scorer(accuracy_score),
        verbose=verbose,
        param_grid=cv_params,
        cv=cv,
        return_train_score=return_train_score,
        n_jobs=n_jobs,
        refit=refit,
        pre_dispatch=pre_dispatch,
        error_score=error_score,
    )

    # scikit-learn removed the ``fit_params`` constructor argument (0.21) and
    # ``iid`` (0.24); passing them unconditionally raises TypeError on newer
    # releases. Keep the original behavior where they exist, otherwise send
    # the fit params through ``fit()`` instead.
    accepted = inspect.signature(GridSearchCV.__init__).parameters
    if 'iid' in accepted:
        search_kwargs['iid'] = iid
    if 'fit_params' in accepted:
        search_kwargs['fit_params'] = cv_fit_params
        search_fit_kwargs = {}
    else:
        search_fit_kwargs = cv_fit_params

    # Note: with refit=True the search also refits its own best_estimator_
    # (bootstrap_fraction=1.0, cv_epochs); the estimator returned below is a
    # separate model.
    search = GridSearchCV(SGDBolasso(**cv_param_dct), **search_kwargs)
    search.fit(data, labels, **search_fit_kwargs)

    # Start from this estimator's full configuration so hyperparameters that
    # were held fixed during the search carry over, instead of silently
    # reverting to SGDBolasso defaults.
    best_param_dct = self.get_params()
    best_param_dct.update(search.best_params_)
    best_param_dct.update({'bootstrap_fraction': self.bootstrap_fraction})

    best_estim = SGDBolasso(**best_param_dct)
    best_estim.fit(data, labels, epochs=epochs)
    return best_estim, search
# Scraped web-page navigation residue ("Comment list", "Table of contents");
# not part of the code.