def test_train_model():
    """Train on a small 20 Newsgroups sample and check the reported metadata.

    The first 200 posts from four ``sci.*`` categories are wrapped into
    document dicts with ``html``, ``url`` and ``relevant`` keys and passed
    to ``train_model``.  ``sci.space`` posts are marked ``True``,
    ``sci.med`` posts are marked ``None`` (presumably meaning "unlabeled":
    the expected description reports 159 of 200 documents as labeled), and
    every other category is marked ``False``.

    The test then pins the exact advice and description texts and checks
    that both weight lists are non-empty and that the returned model is a
    ``BaseModel`` exposing ``predict_proba``.
    """
    newsgroups = fetch_20newsgroups(
        random_state=42,
        categories=['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space'])

    # Keep only the first ``limit`` samples.
    limit = 200
    if limit is not None:
        for field in ('target', 'data'):
            newsgroups[field] = newsgroups[field][:limit]

    # Five documents per synthetic domain: 200 documents -> 40 domains,
    # matching the "across 40 domains" text asserted below.
    n_domains = len(newsgroups['target']) // 5

    # Categories missing from this mapping fall back to False.
    relevance_by_category = {'sci.space': True, 'sci.med': None}

    docs = []
    samples = zip(newsgroups['data'], newsgroups['target'])
    for n, (text, target) in enumerate(samples):
        # Each line of the post becomes its own <p> element.
        paragraphs = ['<p>{}</p>'.format(line) for line in text.split('\n')]
        category = newsgroups['target_names'][target]
        docs.append({
            'html': '\n'.join(paragraphs),
            'url': 'http://example-{}.com/{}'.format(n % n_domains, n),
            'relevant': relevance_by_category.get(category, False),
        })

    result = train_model(docs)
    meta = result.meta
    pprint(attr.asdict(meta))

    expected_advice = [
        {
            'kind': 'Notice',
            'text': "The quality of the classifier is very good, ROC AUC is 0.96. "
                    "You can label more pages if you want to improve quality, "
                    "but it's better to start crawling "
                    "and check the quality of crawled pages.",
        },
    ]
    assert lst_as_dict(meta.advice) == expected_advice

    expected_description = [
        {'heading': 'Dataset',
         'text': '200 documents, 159 labeled across 40 domains.'},
        {'heading': 'Class balance',
         'text': '33% relevant, 67% not relevant.'},
        {'heading': 'Metrics', 'text': ''},
        {'heading': 'Accuracy', 'text': '0.881 ± 0.122'},
        {'heading': 'ROC AUC', 'text': '0.964 ± 0.081'},
    ]
    assert lst_as_dict(meta.description) == expected_description

    # Both feature-weight lists must be non-empty.
    assert len(meta.weights['pos']) > 0
    assert len(meta.weights['neg']) > 0

    model = result.model
    assert isinstance(model, BaseModel)
    assert hasattr(model, 'predict_proba')
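# Web-page residue (not part of the test): "Comment list"
# Web-page residue (not part of the test): "Table of contents"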