def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
    """Restrict labels and predictions to a single class when an ID is given.

    When `selected_id` is None both inputs are handed back untouched.
    Otherwise each input is passed through `_select_class_id` with that ID.

    Args:
      labels: `int64` `Tensor` or `SparseTensor` with shape
        [D1, ... DN, num_labels], where N >= 1 and num_labels is the number
        of target classes for the associated prediction. Commonly, N=1 and
        `labels` has shape [batch_size, num_labels]. The leading dimensions
        [D1, ... DN] must match those of `predictions_idx`.
      predictions_idx: `int64` `Tensor` of class IDs with shape
        [D1, ... DN, k], where N >= 1. Commonly, N=1 and `predictions_idx`
        has shape [batch_size, k].
      selected_id: Integer class ID to keep, or None to skip filtering.

    Returns:
      A `(labels, predictions_idx)` tuple, possibly with classes removed.
    """
    if selected_id is not None:
        # Process labels before predictions, mirroring the argument order.
        selected_labels = _select_class_id(labels, selected_id)
        selected_predictions = _select_class_id(predictions_idx, selected_id)
        return selected_labels, selected_predictions

    # No class requested: pass both inputs through unchanged.
    return labels, predictions_idx
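# Comment list (web-page navigation residue, not part of the code)
# Article table of contents (web-page navigation residue, not part of the code)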