def num_relevant(labels, k):
    """Return how many relevant values each row of `labels` contributes @k.

    A row with `m` labels contributes ``min(m, k)`` relevant values. For a
    dense `labels` of shape [D1, ... DN, num_labels], every row has the same
    `m` (the size of the last dimension). For sparse input, `m` is counted
    separately for each row.

    Args:
      labels: `int64` `Tensor`, `SparseTensor`, or `SparseTensorValue` with
        shape [D1, ... DN, num_labels], where N >= 1 and num_labels is the
        number of target classes for the associated prediction. Commonly,
        N=1 and `labels` has shape [batch_size, num_labels].
      k: Integer, the k of the @k metric.

    Returns:
      Integer `Tensor` of shape [D1, ... DN] giving the number of relevant
      values in each row.

    Raises:
      ValueError: if `k` is less than 1. This is checked before any op is
        created. No dtype validation of `labels` happens here.
    """
    if k < 1:
        raise ValueError('Invalid k=%s.' % k)

    with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
        sparse_types = (
            sparse_tensor.SparseTensor,
            sparse_tensor.SparseTensorValue,
        )
        if isinstance(labels, sparse_types):
            # Sparse rows can hold differing numbers of labels, so each row
            # is counted on its own before being capped at k.
            result = math_ops.minimum(set_ops.set_size(labels), k, name=scope)
        else:
            # Dense rows all share the last-dimension length. Compute a single
            # capped scalar and broadcast it over the leading dimensions.
            # Op creation order is kept (shape, last-dim slice, minimum,
            # leading-dims slice, fill) so auto-generated op names are stable.
            dims = array_ops.shape(labels)
            row_len = dims[-1]
            capped = math_ops.minimum(row_len, k)
            result = array_ops.fill(dims[0:-1], capped, name=scope)
    return result
metric_ops.py source file
python
Views: 19
Bookmarks: 0
Likes: 0
Comments: 0
Comment list
Table of contents