def get_text_field_mask(text_field_tensors: Dict[str, torch.Tensor],
                        num_wrapping_dims: int = 0) -> torch.LongTensor:
    """
    Build a 0/1 padding mask from the tensor dictionary produced by a ``TextField``.

    The dictionary may hold several tensors of different rank (for example word ids
    next to character ids).  The mask is always derived from the tensor with the
    fewest dimensions; when several tensors share that lowest rank, the one that comes
    first in the dictionary's iteration order is used.

    ``num_wrapping_dims`` is the number of ``ListFields`` wrapped around the
    ``TextField``.  It is subtracted from the rank of the chosen tensor, and the result
    selects how the mask is computed:

    * ``2`` - the tensor is read as ``(batch_size, ..., num_tokens)``.  Every entry that
      is not equal to 0 is marked 1, so negative entries count as real tokens.
    * ``3`` - the tensor is read as ``(batch_size, ..., num_tokens, num_features)``,
      typically character ids or some other per-token featurization.  A token is marked
      1 when at least one entry along the last dimension is strictly greater than 0, so
      negative entries do not count here.

    The returned mask therefore has shape ``(batch_size, ..., num_tokens)``, with
    ``num_wrapping_dims`` extra dimensions between the batch and token dimensions.

    The mask is converted with ``.long()`` rather than left as a byte/bool tensor.
    NOTE(review): the original authors chose long masks because summing 8-bit masks
    held in Variables could overflow (260 ones summing to 4); whether that still applies
    depends on the torch version in use.

    Raises ``ValueError`` if the adjusted rank is neither 2 nor 3, and ``IndexError`` if
    ``text_field_tensors`` is empty.
    """
    # ``sorted`` is stable, so among tensors of equal rank the first dictionary entry wins.
    by_rank = sorted(text_field_tensors.values(), key=lambda tensor: tensor.dim())
    lowest_rank_tensor = by_rank[0]
    effective_rank = lowest_rank_tensor.dim() - num_wrapping_dims

    if effective_rank not in (2, 3):
        raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(effective_rank))

    if effective_rank == 2:
        # One id per token: anything other than 0 is a real token.
        return (lowest_rank_tensor != 0).long()

    # One feature vector per token: count the positive features, then flag tokens having any.
    positive_feature_count = (lowest_rank_tensor > 0).long().sum(dim=-1)
    return (positive_feature_count > 0).long()
评论列表
文章目录