def _select_class_id(ids, selected_id):
    """Keep only the entries of `ids` that are equal to `selected_id`.

    Args:
        ids: `int64` `Tensor` or `SparseTensor` of IDs.
        selected_id: Int id to select.

    Returns:
        `SparseTensor` of same dimensions as `ids`, except for the last
        dimension, which might be smaller. This contains only the entries
        equal to `selected_id`.
    """
    sparse_types = (ops.SparseTensor, ops.SparseTensorValue)
    if isinstance(ids, sparse_types):
        # Sparse input: mask the stored values and drop everything else.
        keep_mask = math_ops.equal(ids.values, selected_id)
        return sparse_ops.sparse_retain(ids, keep_mask)

    # TODO(ptucker): Make this more efficient, maybe add a sparse version of
    # tf.equal and tf.reduce_any?

    # Dense input: compute the shape of `ids` with its last dimension
    # collapsed to 1.
    dense_shape = array_ops.shape(ids)
    last_axis = array_ops.size(dense_shape) - 1
    last_axis_vector = array_ops.reshape(last_axis, [1])
    collapsed_shape = math_ops.reduced_shape(dense_shape, last_axis_vector)

    # Fill that shape with the selected ID and intersect it with `ids`.
    selected_id_int64 = math_ops.to_int64(selected_id)
    selected_id_column = array_ops.fill(collapsed_shape, selected_id_int64)
    return set_ops.set_intersection(selected_id_column, ids)
评论列表
文章目录