def get_lengths_from_binary_sequence_mask(mask: torch.Tensor) -> torch.LongTensor:
    """
    Compute the sequence length of each batch element from a binary mask.

    The length of a sequence is the number of non-zero entries in its mask
    row, obtained by summing the mask over its last dimension.

    Parameters
    ----------
    mask : torch.Tensor, required.
        A 2D binary mask of shape (batch_size, sequence_length) to
        calculate the per-batch sequence lengths from.

    Returns
    -------
    A torch.LongTensor of shape (batch_size,) representing the lengths
    of the sequences in the batch.
    """
    # Cast to long before summing so the result is an integer (int64)
    # tensor whatever dtype the mask was stored in.
    return mask.long().sum(-1)
评论列表
文章目录