import torch
from torch.autograd import Variable

from allennlp.nn.util import get_lengths_from_binary_sequence_mask


def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None):  # pylint: disable=arguments-differ
    if mask is not None:
        # Zero out the embeddings at padded positions.
        tokens = tokens * mask.unsqueeze(-1).float()

    # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the
    # `num_tokens` dimension.
    summed = tokens.sum(1)

    if self._averaged:
        if mask is not None:
            lengths = get_lengths_from_binary_sequence_mask(mask)
            length_mask = (lengths > 0)
            # Set any length 0 to 1, to avoid dividing by zero.
            lengths = torch.max(lengths, Variable(lengths.data.new().resize_(1).fill_(1)))
        else:
            # With no mask, every sequence has the full `num_tokens` length.
            lengths = Variable(tokens.data.new().resize_(1).fill_(tokens.size(1)), requires_grad=False)
            length_mask = None

        summed = summed / lengths.unsqueeze(-1).float()

        if length_mask is not None:
            # Re-zero the rows whose sequences were entirely padding.
            summed = summed * (length_mask > 0).float().unsqueeze(-1)

    return summed
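To see what this computes, here is a minimal standalone sketch (not from the original source) of the same masked sum-and-average arithmetic on plain tensors; it uses a clamp-based guard in place of the Variable-era division guard, but the result is the same:

import torch

# Toy batch: 3 sequences of up to 3 tokens, embedding dim 4. The second
# sequence has one padded position and the third is entirely padding.
tokens = torch.randn(3, 3, 4)
mask = torch.tensor([[1, 1, 1],
                     [1, 1, 0],
                     [0, 0, 0]])

# Masked sum over the token dimension, as in the forward pass above.
summed = (tokens * mask.unsqueeze(-1).float()).sum(1)

# Divide by the true lengths, clamping zero lengths to 1 to avoid
# division by zero (the torch.max guard above does the same thing).
lengths = mask.sum(1)
averaged = summed / lengths.clamp(min=1).float().unsqueeze(-1)

# Re-zero rows whose sequences were all padding, mirroring `length_mask`.
averaged = averaged * (lengths > 0).float().unsqueeze(-1)

print(averaged.shape)  # torch.Size([3, 4])
print(averaged[2])     # all zeros: the fully padded sequence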