from typing import Tuple

import torch


def remove_sentence_boundaries(tensor: torch.Tensor,
                               mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Remove begin/end of sentence embeddings from the batch of sentences.
Given a batch of sentences with size ``(batch_size, timesteps, dim)``
this returns a tensor of shape ``(batch_size, timesteps - 2, dim)`` after removing
the beginning and end sentence markers. The sentences are assumed to be padded on the right,
with the beginning of each sentence assumed to occur at index 0 (i.e., ``mask[:, 0]`` is assumed
to be 1).
Returns both the new tensor and updated mask.
This function is the inverse of ``add_sentence_boundary_token_ids``.
Parameters
----------
tensor : ``torch.Tensor``
A tensor of shape ``(batch_size, timesteps, dim)``
mask : ``torch.Tensor``
A tensor of shape ``(batch_size, timesteps)``
Returns
-------
tensor_without_boundary_tokens : ``torch.Tensor``
The tensor after removing the boundary tokens of shape ``(batch_size, timesteps - 2, dim)``
new_mask : ``torch.Tensor``
The new mask for the tensor of shape ``(batch_size, timesteps - 2)``.
"""
    # TODO: matthewp, profile this transfer
    sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
    new_shape = list(tensor.shape)
    new_shape[1] = new_shape[1] - 2
    # Allocate zero-filled outputs on the same device and dtype as the input
    # (``new_zeros`` replaces the deprecated ``Variable(...new(...).fill_(0))`` idiom).
    tensor_without_boundary_tokens = tensor.new_zeros(*new_shape)
    new_mask = tensor.new_zeros((new_shape[0], new_shape[1]), dtype=torch.long)
    for i, j in enumerate(sequence_lengths):
        if j > 2:
            # Drop the boundary embeddings at positions 0 and j - 1, shifting
            # the real tokens left by one.
            tensor_without_boundary_tokens[i, :(j - 2), :] = tensor[i, 1:(j - 1), :]
            new_mask[i, :(j - 2)] = 1
    return tensor_without_boundary_tokens, new_mask
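

# A minimal usage sketch (not part of the original function): build a batch of
# shape (2, 5, 3) where each sentence starts with a boundary embedding at index 0
# and ends with one at position length - 1, then strip them. The tensor values
# and mask below are illustrative assumptions, not taken from the source.
if __name__ == "__main__":
    batch = torch.randn(2, 5, 3)
    # Row 0 has length 5 (3 real tokens + 2 boundaries); row 1 has length 4.
    demo_mask = torch.tensor([[1, 1, 1, 1, 1],
                              [1, 1, 1, 1, 0]])
    stripped, stripped_mask = remove_sentence_boundaries(batch, demo_mask)
    print(stripped.shape)   # torch.Size([2, 3, 3])
    print(stripped_mask)    # rows: [1, 1, 1] and [1, 1, 0]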