def sparse_boolean_mask(sparse_tensor, mask, name="sparse_boolean_mask"):
  """Boolean mask for `SparseTensor`s.

  Each stored entry of `sparse_tensor` is kept or dropped according to the
  value of `mask` at that entry's index along dimension 0.

  Args:
    sparse_tensor: a `SparseTensor`.
    mask: a 1D boolean dense `Tensor` whose length is equal to the 0th
      dimension of `sparse_tensor`.
    name: optional name for this operation.

  Returns:
    A `SparseTensor` that contains row `k` of `sparse_tensor` iff `mask[k]` is
    `True`. The filtering is performed by `sparse_ops.sparse_retain`.
  """
  # TODO(jamieas): consider mask dimension > 1 for symmetry with `boolean_mask`.
  with ops.name_scope(name, values=[sparse_tensor, mask]):
    mask = ops.convert_to_tensor(mask)
    # Index along dimension 0 of every stored entry. Reshape to [-1] rather
    # than squeeze so the result stays rank 1 even when the tensor stores
    # exactly one entry (an axis-less squeeze would collapse it to a scalar).
    first_indices = array_ops.reshape(
        array_ops.slice(sparse_tensor.indices, [0, 0], [-1, 1]), [-1])
    # Look up the mask value for each entry directly. This replaces comparing
    # every entry against every selected row and OR-folding the results, which
    # cost O(selected_rows * entries) and failed when no mask element was True
    # (the fold had nothing to start from).
    # NOTE: relies on the documented contract that len(mask) equals the 0th
    # dimension of `sparse_tensor`; a shorter mask makes the lookup go out of
    # range.
    to_retain = array_ops.gather(mask, first_indices)
    return sparse_ops.sparse_retain(sparse_tensor, to_retain)
# Web-page residue, not part of this module: "Comment list" / "Table of contents".