def sparse_boolean_mask(tensor, mask):
    """
    Creates a left-packed sparse tensor from the masked elements of `tensor`.

    Within each row, the elements of `tensor` whose `mask` entry is set keep
    their relative order and are placed in columns 0..k-1, where k is the
    number of set mask entries in that row.

    Inputs:
        tensor: a 2-D tensor, [batch_size, T]
        mask: a 2-D mask, [batch_size, T]
    Output:
        a 2-D sparse tensor whose dense shape is [batch_size, max_k], with
        max_k the largest per-row count of set mask entries.
    """
    # Per-row count of selected elements, shaped [batch_size, 1] so it can be
    # compared against a grid of column indices. `expand_dims` is used rather
    # than the keep_dims/keepdims keyword because that keyword's spelling
    # differs between TensorFlow releases.
    mask_lens = tf.expand_dims(
        tf.reduce_sum(tf.cast(mask, tf.int32), -1), -1
    )
    mask_shape = tf.shape(mask)

    # Grid of column indices, one copy per row: [batch_size, T].
    column_ids = tf.tile(
        tf.expand_dims(tf.range(mask_shape[1]), 0),
        [mask_shape[0], 1],
    )
    # True for the first k columns of each row: the slots the kept values
    # occupy once they are shifted to the left.
    left_shifted_mask = column_ids < mask_lens

    # `tf.pack` was renamed to `tf.stack`; use whichever this release has.
    stack = tf.stack if hasattr(tf, "stack") else tf.pack
    dense_shape = tf.cast(
        stack([mask_shape[0], tf.reduce_max(mask_lens)]), tf.int64
    )  # For 2D only

    # `tf.where` and `tf.boolean_mask` both walk their input in row-major
    # order, so the i-th index lines up with the i-th value. The third
    # argument is passed positionally because its keyword was renamed from
    # `shape` to `dense_shape`.
    return tf.SparseTensor(
        tf.where(left_shifted_mask),
        tf.boolean_mask(tensor, mask),
        dense_shape,
    )
# Comment list
# Table of contents