train.py source code

python
Project: hh-page-classifier  Author: TeamHG-Memex
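from sklearn.model_selection import GroupKFold, KFold

# get_domain, AdviceItem, WARNING, WARN_N_RELEVANT_DOMAINS and two_class_folds
# are not shown in this excerpt; they are assumed to come from the project.


def build_folds(all_xs, all_ys, advice):
    """Build cross-validation folds, grouped by domain where possible.

    Appends warnings to ``advice`` and returns the folds kept by
    ``two_class_folds`` (possibly empty).
    """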
    domains = [get_domain(doc['url']) for doc in all_xs]
    n_domains = len(set(domains))
    n_relevant_domains = len(
        {domain for domain, is_relevant in zip(domains, all_ys) if is_relevant})
    n_folds = 4
    if n_relevant_domains == 1:
        advice.append(AdviceItem(
            WARNING,
            'Only 1 relevant domain in data means that it\'s impossible to do '
            'cross-validation across domains, '
            'and will likely result in model over-fitting.'
        ))
        folds = KFold(n_splits=n_folds).split(all_xs)
    else:
        folds = (GroupKFold(n_splits=min(n_domains, n_folds))
                 .split(all_xs, groups=domains))

    if 1 < n_relevant_domains < WARN_N_RELEVANT_DOMAINS:
        advice.append(AdviceItem(
            WARNING,
            'Low number of relevant domains (just {}) '
            'might result in model over-fitting.'.format(n_relevant_domains)
        ))
    folds = two_class_folds(folds, all_ys)
    if not folds:
        folds = two_class_folds(KFold(n_splits=n_folds).split(all_xs), all_ys)
    if not folds:
        advice.append(AdviceItem(
            WARNING,
            'Can not do cross-validation, as there are no folds where '
            'training data has both relevant and non-relevant examples. '
            'There are too few domains or the dataset is too unbalanced.'
        ))
    return folds
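
The excerpt calls `two_class_folds` without showing it. Judging only from the final warning message, it appears to keep the folds whose training part contains both relevant and non-relevant examples. The following is a minimal sketch under that assumption; the name `two_class_folds_sketch` and its body are hypothetical and are not the project's actual implementation.

```python
def two_class_folds_sketch(folds, all_ys):
    """Keep only folds whose training part contains both classes.

    Hypothetical stand-in for the project's two_class_folds helper.
    """
    kept = []
    for train_idx, test_idx in folds:
        # Collect the distinct labels seen in the training part of the fold.
        train_labels = {bool(all_ys[i]) for i in train_idx}
        if len(train_labels) == 2:
            kept.append((list(train_idx), list(test_idx)))
    return kept


# Toy data: six documents, only the first two are relevant.
ys = [True, True, False, False, False, False]
# Hand-written (train, test) index pairs standing in for KFold output.
folds = [
    ([2, 3, 4, 5], [0, 1]),  # training part has no relevant examples
    ([0, 1, 4, 5], [2, 3]),
    ([0, 1, 2, 3], [4, 5]),
]
usable = two_class_folds_sketch(folds, ys)
print(len(usable))
```

Returning a list rather than a generator matters here: `build_folds` tests the result with `if not folds`, which only works when an empty result is falsy.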