def test_precision_recall_f_ignored_labels():
    """Check that recall can be computed for a requested subset of labels.

    The same expectations are verified twice: once on the raw label lists
    and once on their ``label_binarize`` encodings over classes ``0..4``.
    Restricting ``labels`` to ``[1, 3]`` must yield the hand-computed values
    below, and each averaged score must differ from the ``labels=None``
    result so that the restriction is shown to have an effect.
    """
    # Test a subset of labels may be requested for PRF
    y_true = [1, 1, 2, 3]
    y_pred = [1, 3, 3, 3]
    y_true_bin = label_binarize(y_true, classes=np.arange(5))
    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
    data = [(y_true, y_pred),
            (y_true_bin, y_pred_bin)]

    # Distinct loop names keep the original fixtures from being rebound.
    for truth, prediction in data:
        recall_13 = partial(recall_score, truth, prediction, labels=[1, 3])
        recall_all = partial(recall_score, truth, prediction, labels=None)

        # Label 1 is predicted correctly for one of its two samples;
        # label 3 is predicted correctly for its only sample.
        assert_array_almost_equal([.5, 1.], recall_13(average=None))
        assert_almost_equal((.5 + 1.) / 2, recall_13(average='macro'))
        # Weighted average uses the supports: two samples of 1, one of 3.
        assert_almost_equal((.5 * 2 + 1. * 1) / 3,
                            recall_13(average='weighted'))
        # Micro average: 2 of the 3 samples labelled 1 or 3 are correct.
        assert_almost_equal(2. / 3, recall_13(average='micro'))

        # ensure the above were meaningful tests:
        for average in ['macro', 'weighted', 'micro']:
            assert_not_equal(recall_13(average=average),
                             recall_all(average=average))
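# (Scraped page navigation, not part of the test: "Comment list")
# (Scraped page navigation, not part of the test: "Article table of contents")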