def fit(self, X, y):
    """Estimate cross-validated rates, then fit the classifier on all data.

    Steps performed:

    1. Check that ``y`` contains exactly two distinct labels and record the
       less frequent one in ``self._min_label``.
    2. Default ``self._reference_label`` to the minority label when it is
       ``None``.
    3. Build a shuffled cross-validation split (stratified when the minority
       class has at least ``self._n_folds`` samples, plain K-fold otherwise).
    4. Run ``self._fit_fold`` on every fold in a process pool and sum the six
       counts it returns, in order ``(tp, fp, ptp, pfn, pfp, ptn)``.
    5. Store the derived rates in ``self._tpr``, ``self._fpr``,
       ``self._ptpr`` and ``self._pfpr``.
    6. Fit ``self._clf`` on the full data and record which column of
       ``self._clf.classes_`` is the minority label (``self._pos_idx``) and
       which is the other label (``self._neg_idx``).

    Args:
        X: Training samples; only ``X.shape[0]`` is read here, the object is
            otherwise passed through to ``_fit_fold`` and ``self._clf.fit``.
        y: Training labels, one per sample; must contain exactly two
            distinct values.

    Raises:
        Exception: If ``y`` does not contain exactly two distinct labels, or
            if ``self._reference_label`` does not occur in ``y``.
    """
    labels = list(set(y))
    if len(labels) != 2:
        raise Exception("A binary setup is required")

    n_samples = X.shape[0]
    y_list = list(y)  # converted once instead of once per label

    # Find the minority label. `<=` means that on a tie the label visited
    # last wins, exactly as before.
    min_count = n_samples
    self._min_label = None
    for label in labels:
        count = y_list.count(label)
        if count <= min_count:
            min_count = count
            self._min_label = label

    if self._reference_label is None:
        self._reference_label = self._min_label
    if self._reference_label not in labels:
        raise Exception("Reference label does not appear in training data")

    # Stratification is only attempted when the minority class has at
    # least one sample per requested fold.
    n_folds = min(n_samples, self._n_folds)
    if min_count >= self._n_folds:
        cv = cross_validation.StratifiedKFold(
            y, n_folds=n_folds, shuffle=True, random_state=self._seed)
    else:
        cv = cross_validation.KFold(
            n_samples, n_folds=n_folds, shuffle=True,
            random_state=self._seed)

    requests = [(X, y, train_cv, test_cv) for train_cv, test_cv in cv]

    tp = 0
    fp = 0
    ptp = 0
    pfn = 0
    pfp = 0
    ptn = 0
    pool = Pool(processes=10)
    try:
        # The per-fold values need their own names: reusing the accumulator
        # names as loop targets would overwrite the running totals on every
        # iteration.
        for (fold_tp, fold_fp, fold_ptp,
             fold_pfn, fold_pfp, fold_ptn) in pool.map(self._fit_fold,
                                                       requests):
            tp += fold_tp
            fp += fold_fp
            ptp += fold_ptp
            pfn += fold_pfn
            pfp += fold_pfp
            ptn += fold_ptn
    finally:
        # Always release the worker processes, even when a fold fails.
        pool.close()
        pool.join()

    # NOTE(review): the rates are normalised by the minority-class count
    # even when `_reference_label` was explicitly set to the other label --
    # confirm this matches what `_fit_fold` counts as a positive.
    positives = min_count
    negatives = n_samples - positives
    # NOTE(review): these divisions assume true division (Python 3 or
    # `from __future__ import division`); with integer counts under
    # Python 2 they would truncate -- confirm.
    self._tpr = tp / positives
    self._fpr = fp / negatives
    self._ptpr = ptp / (ptp + pfn)
    self._pfpr = pfp / (pfp + ptn)

    self._clf.fit(X, y)
    if self._clf.classes_[0] == self._min_label:
        self._pos_idx = 0
        self._neg_idx = 1
    else:
        self._neg_idx = 0
        self._pos_idx = 1
评论列表
文章目录