metric_ops_test.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_three_labels_at_k5_some_out_of_range(self):
    """Checks precision@5 when some sparse labels fall outside [0, n_classes).

    Every example has 10 class scores and four label slots. The label values
    -1 and 10 are not valid class ids for 10 classes and are expected to be
    ignored, which leaves three usable labels per example.
    """
    scores = [
        [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
        [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
    ]
    # Class ids of the five highest scores in each row above, best first.
    top5_class_ids = [
        [9, 4, 6, 2, 0],
        [5, 7, 2, 9, 6],
    ]

    # Row-major label values for a dense 2 x 4 layout; -1 and 10 are the
    # out-of-range entries this test is about.
    label_values = np.array(
        [2, 7, -1, 8,
         1, 2, 5, 10],
        dtype=np.int64)
    sparse_labels = tf.SparseTensorValue(
        indices=[[row, col] for row in range(2) for col in range(4)],
        values=label_values,
        shape=[2, 4])

    # (class_id, expected precision) for the per-class checks.
    per_class_cases = (
        # Class 2 is in the top 5 of both examples and labelled in both.
        (2, 2.0 / 2),
        # Class 5 is predicted once (second example) and labelled there.
        (5, 1.0 / 1),
        # Class 7 is predicted for the second example but only labelled in
        # the first, so its single prediction is wrong.
        (7, 0.0 / 1),
    )
    for class_id, expected in per_class_cases:
        self._test_streaming_sparse_precision_at_k(
            scores, sparse_labels, k=5, expected=expected, class_id=class_id)
        self._test_streaming_sparse_precision_at_top_k(
            top5_class_ids, sparse_labels, expected=expected,
            class_id=class_id)

    # Across all classes: 2 examples x 5 predictions = 10, of which 3 match
    # an in-range label.
    overall_expected = 3.0 / 10
    self._test_streaming_sparse_precision_at_k(
        scores, sparse_labels, k=5, expected=overall_expected)
    self._test_streaming_sparse_precision_at_top_k(
        top5_class_ids, sparse_labels, expected=overall_expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号