def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
    if mask is not None:
        tokens = tokens * mask.unsqueeze(-1).float()
    # Our input is expected to have shape `(batch_size, num_tokens, embedding_dim)`. The
    # convolution layers expect input of shape `(batch_size, in_channels, sequence_length)`,
    # where the conv layer `in_channels` is our `embedding_dim`. We thus need to transpose the
    # tensor first.
    tokens = torch.transpose(tokens, 1, 2)
    # Each convolution layer returns output of size `(batch_size, num_filters, pool_length)`,
    # where `pool_length = num_tokens - ngram_size + 1`. We then apply an activation function,
    # then do max pooling over each filter for the whole input sequence. Because our max
    # pooling is simple, we just use `torch.max`. The resulting tensor has shape
    # `(batch_size, num_conv_layers * num_filters)`, which then gets projected using the
    # projection layer, if requested.
    filter_outputs = [self._activation(convolution_layer(tokens)).max(dim=2)[0]
                      for convolution_layer in self._convolution_layers]
    # Now we have a list of `num_conv_layers` tensors of shape `(batch_size, num_filters)`.
    # Concatenating them gives us a tensor of shape `(batch_size, num_filters * num_conv_layers)`.
    maxpool_output = torch.cat(filter_outputs, dim=1) if len(filter_outputs) > 1 else filter_outputs[0]
    if self.projection_layer:
        result = self.projection_layer(maxpool_output)
    else:
        result = maxpool_output
    return result
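
To make the shape bookkeeping concrete, here is a minimal, self-contained sketch of the same pipeline using plain `torch` calls. The concrete sizes (`batch_size=4`, `num_tokens=50`, `embedding_dim=300`, `num_filters=100`, ngram sizes 2 and 3) and the use of `torch.relu` in place of the encoder's configured activation are assumptions chosen for illustration, not values taken from the code above.

    import torch

    # Assumed sizes for illustration only.
    batch_size, num_tokens, embedding_dim = 4, 50, 300
    num_filters, ngram_sizes = 100, (2, 3)

    tokens = torch.randn(batch_size, num_tokens, embedding_dim)
    conv_layers = [torch.nn.Conv1d(in_channels=embedding_dim,
                                   out_channels=num_filters,
                                   kernel_size=ngram_size)
                   for ngram_size in ngram_sizes]

    # (batch_size, num_tokens, embedding_dim) -> (batch_size, embedding_dim, num_tokens)
    x = tokens.transpose(1, 2)                            # (4, 300, 50)

    # Each conv output is (4, 100, 50 - ngram_size + 1); max over dim=2 leaves (4, 100).
    filter_outputs = [torch.relu(conv(x)).max(dim=2)[0]
                      for conv in conv_layers]

    # Concatenate along the filter dimension: (batch_size, num_filters * num_conv_layers).
    encoded = torch.cat(filter_outputs, dim=1)
    print(encoded.shape)                                  # torch.Size([4, 200])

The final `(4, 200)` tensor corresponds to `maxpool_output` above; the optional projection layer would simply be a linear map applied to it.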