test_train.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:hh-page-classifier 作者: TeamHG-Memex 项目源码 文件源码
def test_train_model():
    """Train a model on a 20-newsgroups subset and check the reported metadata.

    The first 200 posts from four ``sci.*`` categories are wrapped into
    page-like dicts (``html``, ``url``, ``relevant``) and passed to
    ``train_model``.  The test then checks the advice, the description
    entries, the presence of positive/negative weights, and the type of the
    returned model.
    """
    data = fetch_20newsgroups(
        random_state=42,
        categories=['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space'])

    # Keep only the first ``limit`` samples; ``None`` would disable truncation.
    limit = 200
    if limit is not None:
        for key in ('target', 'data'):
            data[key] = data[key][:limit]

    # Five documents per synthetic domain on average.
    n_domains = len(data['target']) // 5

    # Category name -> label; every category not listed here maps to False.
    # NOTE(review): None presumably marks a page as unlabeled (the expected
    # description below reports 159 labeled out of 200) - confirm.
    relevance_by_category = {'sci.space': True, 'sci.med': None}

    docs = []
    for n, (text, target) in enumerate(zip(data['data'], data['target'])):
        # Each line of the post becomes its own <p> element.
        paragraphs = ['<p>{}</p>'.format(line) for line in text.split('\n')]
        category = data['target_names'][target]
        docs.append({
            'html': '\n'.join(paragraphs),
            'url': 'http://example-{}.com/{}'.format(n % n_domains, n),
            'relevant': relevance_by_category.get(category, False),
        })

    result = train_model(docs)
    meta = result.meta
    pprint(attr.asdict(meta))

    expected_advice = [
        {'kind': 'Notice',
         'text': "The quality of the classifier is very good, ROC AUC is 0.96. "
                 "You can label more pages if you want to improve quality, "
                 "but it's better to start crawling "
                 "and check the quality of crawled pages.",
         },
    ]
    assert lst_as_dict(meta.advice) == expected_advice

    expected_description = [
        {'heading': 'Dataset',
         'text': '200 documents, 159 labeled across 40 domains.'},
        {'heading': 'Class balance',
         'text': '33% relevant, 67% not relevant.'},
        {'heading': 'Metrics', 'text': ''},
        {'heading': 'Accuracy', 'text': '0.881 ± 0.122'},
        {'heading': 'ROC AUC', 'text': '0.964 ± 0.081'},
    ]
    assert lst_as_dict(meta.description) == expected_description

    # Both positive and negative feature weights must be reported.
    for sign in ('pos', 'neg'):
        assert len(meta.weights[sign]) > 0

    model = result.model
    assert isinstance(model, BaseModel)
    assert hasattr(model, 'predict_proba')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号