metric_ops_test.py source code

python

Project: lsdc  Author: febert
def test_three_labels_at_k5_some_out_of_range(self):
    """Tests that labels outside the [0, n_classes) range count in the denominator."""
    predictions = [
        [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
        [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    sp_labels = tf.SparseTensorValue(
        indices=[[0, 0], [0, 1], [0, 2], [0, 3],
                 [1, 0], [1, 1], [1, 2], [1, 3]],
        # values -1 and 10 are outside the [0, n_classes) range.
        values=np.array([2, 7, -1, 8,
                         1, 2, 5, 10], np.int64),
        shape=[2, 4])

    # Class 2: 2 labels, both correct.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=2.0 / 2,
        class_id=2)

    # Class 5: 1 label, correct.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=1.0 / 1,
        class_id=5)

    # Class 7: 1 label, incorrect.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=0.0 / 1,
        class_id=7)

    # All classes: 8 labels, 3 correct.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
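
The expected values in this test can be checked by hand with a small NumPy reference. The sketch below is an assumption-laden illustration, not the TensorFlow implementation: `recall_at_k` is a hypothetical helper, and the sparse labels are written as plain per-row lists. It counts every label in the denominator, including the out-of-range values -1 and 10, and counts a label as correct only if it appears among the top-k predicted classes for its row.

```python
import numpy as np

# Same data as the test above; labels are given as one list per example.
predictions = [
    [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
labels = [[2, 7, -1, 8],
          [1, 2, 5, 10]]


def recall_at_k(predictions, labels, k, class_id=None):
    """Hypothetical NumPy reference for recall@k (not the TensorFlow op)."""
    scores = np.asarray(predictions)
    # Indices of the k highest-scoring classes in each row.
    top_k = np.argsort(-scores, axis=1)[:, :k]
    true_positives = 0
    total_labels = 0
    for row_top_k, row_labels in zip(top_k, labels):
        for label in row_labels:
            # When class_id is given, only labels of that class are considered.
            if class_id is not None and label != class_id:
                continue
            # Out-of-range labels (-1, 10) still count in the denominator.
            total_labels += 1
            if label in row_top_k:
                true_positives += 1
    return true_positives / total_labels if total_labels else float("nan")


print(recall_at_k(predictions, labels, k=5, class_id=2))  # 2 of 2 correct
print(recall_at_k(predictions, labels, k=5, class_id=5))  # 1 of 1 correct
print(recall_at_k(predictions, labels, k=5, class_id=7))  # 0 of 1 correct
print(recall_at_k(predictions, labels, k=5))              # 3 of 8 correct
```

The top-5 classes are {9, 4, 6, 2, 0} for the first row and {5, 7, 2, 9, 6} for the second, so labels 2 (both rows) and 5 (second row) are the only hits.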