def test_three_labels_at_k5_some_out_of_range(self):
    """Tests that labels outside the [0, n_classes) count in denominator.

    Each of the two examples has ten prediction scores (so the valid class
    ids are 0..9) and four labels. Labels -1 and 10 are out of range and can
    never be retrieved, yet they are still expected to count as labels when
    recall is computed over all classes.

    With k=5 the top-scoring classes are:
      * example 0: {9, 4, 6, 2, 0}
      * example 1: {5, 7, 2, 9, 6}
    """
    predictions = [
        [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
        [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
    sp_labels = tf.SparseTensorValue(
        indices=[[0, 0], [0, 1], [0, 2], [0, 3],
                 [1, 0], [1, 1], [1, 2], [1, 3]],
        # values -1 and 10 are outside the [0, n_classes) range.
        values=np.array([2, 7, -1, 8,
                         1, 2, 5, 10], np.int64),
        shape=[2, 4])

    # Class 2: 2 labels (one per example), both in the top 5.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=2.0 / 2,
        class_id=2)

    # Class 5: 1 label (example 1), correct -- 5 is that example's top score.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=1.0 / 1,
        class_id=5)

    # Class 7: 1 label (example 0), incorrect -- 7 is not in its top 5.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=0.0 / 1,
        class_id=7)

    # All classes: 8 labels (the 2 out-of-range ones included), 3 correct.
    self._test_streaming_sparse_recall_at_k(
        predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
# Comment list
# Table of contents