def build_folds(all_xs, all_ys, advice):
    """Choose cross-validation folds for the documents, grouping by domain.

    The domain of every document is taken from ``get_domain(doc['url'])``.
    When more than one domain has a relevant example, folds come from
    ``GroupKFold`` with the domains as groups, using at most 4 splits.
    When exactly one domain has relevant examples, a plain ``KFold`` with
    4 splits is used instead and a warning is recorded.

    The candidate folds are passed through ``two_class_folds``; if that
    yields nothing, a plain ``KFold`` split is tried as a fallback, and if
    that also yields nothing a final warning is recorded.

    :param all_xs: documents; each must support ``doc['url']``.
    :param all_ys: per-document labels, truthy for relevant documents,
        aligned with ``all_xs``.
    :param advice: list that warnings (``AdviceItem``) are appended to;
        mutated in place.
    :return: the result of ``two_class_folds`` (falsy when no usable folds
        were found).
    """
    fold_limit = 4

    def plain_split():
        # Domain-agnostic split, used both as the single-relevant-domain
        # strategy and as the fallback below.
        return KFold(n_splits=fold_limit).split(all_xs)

    domains = [get_domain(doc['url']) for doc in all_xs]
    n_domains = len(set(domains))
    relevant_domains = set()
    for domain, is_relevant in zip(domains, all_ys):
        if is_relevant:
            relevant_domains.add(domain)
    n_relevant = len(relevant_domains)

    if n_relevant == 1:
        advice.append(AdviceItem(
            WARNING,
            'Only 1 relevant domain in data means that it\'s impossible to do '
            'cross-validation across domains, '
            'and will likely result in model over-fitting.'
        ))
        candidate = plain_split()
    else:
        # NOTE(review): if GroupKFold is scikit-learn's splitter it rejects
        # n_splits < 2, which min(...) can produce when there are fewer than
        # two domains - confirm callers rule that case out.
        splitter = GroupKFold(n_splits=min(n_domains, fold_limit))
        candidate = splitter.split(all_xs, groups=domains)
        # Only reachable with more than one relevant domain, so the lower
        # bound of the chained comparison is kept purely for fidelity.
        if 1 < n_relevant < WARN_N_RELEVANT_DOMAINS:
            advice.append(AdviceItem(
                WARNING,
                'Low number of relevant domains (just {}) '
                'might result in model over-fitting.'.format(n_relevant)
            ))

    # Fall back to a plain split only when the preferred one gave nothing.
    folds = (two_class_folds(candidate, all_ys)
             or two_class_folds(plain_split(), all_ys))
    if not folds:
        advice.append(AdviceItem(
            WARNING,
            'Can not do cross-validation, as there are no folds where '
            'training data has both relevant and non-relevant examples. '
            'There are too few domains or the dataset is too unbalanced.'
        ))
    return folds
# Stray web-page navigation text (not part of the code): "Comment list", "Table of contents"