def cross_vote(self, sequences, bin_sites, fit_batch_size=500,
               pre_batch_size=200, max_splits=100000,
               active_learning=False, random_state=1234, n_jobs=-1):
    """Run a 2-fold cross fit and collect the votes of both folds.

    The input is divided into two parts with ``balanced_split``. The
    estimator is fitted on the first part and votes on the second; the
    roles are then swapped. The votes of both rounds are merged into a
    single dict.

    Parameters
    ----------
    sequences, bin_sites
        Forwarded to ``balanced_split``; ``bin_sites`` is also forwarded
        to ``self._fit`` in each fold.
    fit_batch_size
        Forwarded to ``self._fit``.
    pre_batch_size
        Forwarded to ``self.vote``.
    max_splits, random_state, n_jobs
        Forwarded to both ``self._fit`` and ``self.vote``;
        ``random_state`` also seeds ``balanced_split``.
    active_learning
        Forwarded to ``self._fit``.

    Returns
    -------
    dict
        Union of the votes of both folds. Merging uses ``dict.update``,
        so an entry from fold 2 replaces a fold 1 entry with the same key.
    """
    half_a, half_b = balanced_split(sequences, bin_sites, n_splits=2,
                                    random_state=random_state)
    # Every half is read twice (once for fitting, once for voting), so
    # each one is duplicated up front -- presumably because the halves
    # are single-pass iterators (TODO confirm against balanced_split).
    half_a, half_a_copy = tee(half_a)
    half_b, half_b_copy = tee(half_b)

    # (log label, training part, held-out part) for each round.
    folds = (
        ("Fold 1", half_a, half_b),
        ("Fold 2", half_b_copy, half_a_copy),
    )

    merged = {}
    for label, train_part, test_part in folds:
        logger.debug(label)
        self._fit(train_part, bin_sites, fit_batch_size, max_splits,
                  active_learning, random_state, n_jobs)
        fold_votes = self.vote(test_part, pre_batch_size, max_splits,
                               random_state, n_jobs)
        merged.update(fold_votes)
    return merged
# [web-page residue] Comment list
# [web-page residue] Table of contents