def fit_thresholds(self, data, alpha, batch_size=128, verbose=0,
                   validation_data=None, cv=None, top_k=None):
    """Fit one ridge-regression threshold model per graph output.

    For every output name ``k`` in ``self._graph_outputs``, threshold
    targets ``T`` are built by ``self._construct_thresholds`` from the
    model's predictions on ``data`` and the true values ``data[k]``.
    A ``lm.Ridge`` model (``lm`` is presumably ``sklearn.linear_model`` --
    confirm against the file's imports) is then fit to map the horizontally
    stacked graph inputs to ``T``. The fitted models replace
    ``self.t_models``, a dict keyed by output name.

    Args:
        data: mapping that contains every key of ``self._graph_inputs`` and
            ``self._graph_outputs``; input arrays are ``np.hstack``-ed.
        alpha: ridge regularization strength. If a ``list``, it is a grid
            of candidates and the best one is chosen independently for
            each output (see ``validation_data`` / ``cv``). Any other value
            is used as-is for every output.
        batch_size: batch size forwarded to ``self.predict`` (for both the
            training and the validation data).
        verbose: if truthy, write progress to stdout.
        validation_data: optional mapping with the same keys as ``data``.
            When given and ``alpha`` is a list, each candidate is fit on
            ``data`` and ranked by its ``score`` on this data; it takes
            precedence over ``cv``.
        cv: cross-validation spec forwarded to ``lm.RidgeCV`` when ``alpha``
            is a list and ``validation_data`` is None.
        top_k: forwarded to ``self._construct_thresholds`` for both the
            training and the validation threshold targets.

    Raises:
        ValueError: if ``alpha`` is a list, ``validation_data`` is given,
            and no candidate yields a score greater than ``-inf`` (e.g.
            the list is empty or all scores are NaN).
    """
    inputs = np.hstack([data[k] for k in self._graph_inputs])
    probs = self.predict(data, batch_size=batch_size)
    targets = {k: data[k] for k in self._graph_outputs}

    search_alpha = isinstance(alpha, list)
    # An explicit validation set takes precedence over cross-validation.
    use_validation = search_alpha and validation_data is not None
    if search_alpha and not use_validation and cv is None:
        warnings.warn("Neither validation data, nor the number of "
                      "cross-validation folds is provided. "
                      "The alpha parameter for threshold model will "
                      "be selected based on the default "
                      "cross-validation procedure in RidgeCV.")
    if use_validation:
        val_inputs = np.hstack([validation_data[k]
                                for k in self._graph_inputs])
        val_probs = self.predict(validation_data, batch_size=batch_size)
        val_targets = {k: validation_data[k]
                       for k in self._graph_outputs}

    if verbose:
        sys.stdout.write("Constructing thresholds.")
        sys.stdout.flush()
    self.t_models = {}
    for k in self._graph_outputs:
        if verbose:
            sys.stdout.write(".")
            sys.stdout.flush()
        # Build training targets exactly like the validation targets so
        # candidate alphas are scored against comparable thresholds.
        T = self._construct_thresholds(probs[k], targets[k], top_k=top_k)

        # Select alpha into a per-output local: rebinding `alpha` itself
        # would discard the candidate list after the first output.
        if not search_alpha:
            alpha_k = alpha
        elif use_validation:
            val_T = self._construct_thresholds(val_probs[k],
                                               val_targets[k],
                                               top_k=top_k)
            score_best, alpha_k = -np.inf, None
            for a in alpha:
                model = lm.Ridge(alpha=a).fit(inputs, T)
                score = model.score(val_inputs, val_T)
                if score > score_best:
                    score_best, alpha_k = score, a
            if alpha_k is None:
                raise ValueError(
                    "Could not select alpha for output %r: no candidate "
                    "produced a comparable validation score." % (k,))
        else:
            model = lm.RidgeCV(alphas=alpha, cv=cv).fit(inputs, T)
            alpha_k = model.alpha_

        self.t_models[k] = lm.Ridge(alpha=alpha_k)
        self.t_models[k].fit(inputs, T)
    if verbose:
        sys.stdout.write("Done.\n")
        sys.stdout.flush()
评论列表
文章目录