test_classification.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def test_precision_recall_f_extra_labels():
    """Test handling of explicit additional (not in input) labels to PRF.

    ``labels`` may list classes that occur in neither ``y_true`` nor
    ``y_pred`` (here 0 and 4). They must show up as zeros in per-class
    (``average=None``) results and therefore lower the macro average.
    They must not change the micro, weighted or samples averages.

    In the multilabel (binarized) case, labels outside the range of the
    binarized columns must raise ``ValueError``.
    """
    y_true = [1, 3, 3, 2]
    y_pred = [1, 1, 3, 2]
    # Binarize over 5 classes so that columns 0 and 4 exist but are all-zero.
    y_true_bin = label_binarize(y_true, classes=np.arange(5))
    y_pred_bin = label_binarize(y_pred, classes=np.arange(5))
    data = [(y_true, y_pred),
            (y_true_bin, y_pred_bin)]

    # Labels 0 and 4 never occur in y_true or y_pred.
    extra_labels = [0, 1, 2, 3, 4]
    # Per-class recall: 1 -> 1/1, 2 -> 1/1, 3 -> 1/2, absent labels -> 0.
    expected_per_class = [0., 1., 1., .5, 0.]

    # Use distinct loop names so the outer y_true / y_pred are not rebound.
    for i, (y_t, y_p) in enumerate(data):
        # No average: zeros in array for the absent labels
        actual = recall_score(y_t, y_p, labels=extra_labels, average=None)
        assert_array_almost_equal(expected_per_class, actual)

        # Macro average is changed: the absent labels pull the mean down
        actual = recall_score(y_t, y_p, labels=extra_labels,
                              average='macro')
        assert_almost_equal(np.mean(expected_per_class), actual)

        # No effect otherwise
        for average in ['micro', 'weighted', 'samples']:
            # Skip 'samples' for the non-binarized entry (i == 0).
            # Sample-based averaging presumably applies only to multilabel
            # input.
            if average == 'samples' and i == 0:
                continue
            assert_almost_equal(recall_score(y_t, y_p,
                                             labels=extra_labels,
                                             average=average),
                                recall_score(y_t, y_p, labels=None,
                                             average=average))

    # Error when introducing invalid label in multilabel case
    # (although it would only affect performance if average='macro'/None)
    for average in [None, 'macro', 'micro', 'samples']:
        # Label 5 is past the 5 binarized columns (0..4).
        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
                      labels=np.arange(6), average=average)
        # Label -1 is negative, so it cannot index a column.
        assert_raises(ValueError, recall_score, y_true_bin, y_pred_bin,
                      labels=np.arange(-1, 4), average=average)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号