def _last_dimension_applicator(function_to_apply: Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
                               tensor: torch.Tensor,
                               mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Apply ``function_to_apply`` over the last dimension of ``tensor``.

    We assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if
    given) has shape ``(batch_size, sequence_length)``. The mask is unsqueezed at dimension 1
    until it has as many dimensions as the tensor, expanded to the tensor's shape and cast to
    float. Both are then flattened to 2D ``(-1, sequence_length)``, passed through the function,
    and the result is reshaped back to the tensor's original shape.

    Parameters
    ----------
    function_to_apply : ``Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]``
        Called as ``function_to_apply(flat_tensor, flat_mask)``. ``flat_tensor`` is 2D and
        ``flat_mask`` is either a 2D float tensor of the same shape or ``None``. It must return
        a tensor with the same number of elements as ``tensor``.
    tensor : ``torch.Tensor``
        Input of shape ``(batch_size, ..., sequence_length)``. It does not need to be contiguous.
    mask : ``torch.Tensor``, optional (default = ``None``)
        Presumably of shape ``(batch_size, sequence_length)``. It must be expandable to
        ``tensor``'s shape after being unsqueezed at dimension 1.

    Returns
    -------
    ``torch.Tensor``
        The function's output, reshaped to ``tensor``'s original shape.
    """
    tensor_shape = tensor.size()
    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs, e.g. transposed tensors.
    # When possible it still returns a view, so contiguous inputs behave exactly as before.
    reshaped_tensor = tensor.reshape(-1, tensor.size(-1))
    if mask is not None:
        # Insert singleton dims right after the batch dim so that
        # (batch_size, sequence_length) lines up with (batch_size, ..., sequence_length).
        while mask.dim() < tensor.dim():
            mask = mask.unsqueeze(1)
        # expand_as yields a broadcast (stride-0) view; make it contiguous so it can be flattened.
        mask = mask.expand_as(tensor).contiguous().float()
        mask = mask.view(-1, mask.size(-1))
    reshaped_result = function_to_apply(reshaped_tensor, mask)
    # ``reshape`` also tolerates a result whose memory layout ``view`` could not handle.
    return reshaped_result.reshape(*tensor_shape)
评论列表
文章目录