def forward(self, x):
    """
    Encode a padded batch of sentences with a bidirectional RNN and apply
    global max/average pooling over the time dimension.

    :param x: Tuple of (sentences, sentence_lengths). `sentences` is a padded
        tensor batch of shape [seq. length, batch size, embed. size];
        `sentence_lengths` holds each sentence's true (unpadded) length.
    :return: Pooled sentence embeddings of shape [batch size, hidden size],
        where hidden size is 2 * num. layers * num. hidden units for the
        bidirectional encoder.
    :raises ValueError: if ``self.pooling_mode`` is neither "max" nor "avg".
    """
    sentences, sentence_lengths = x
    # pack_padded_sequence requires the batch sorted by descending length;
    # keep the inverse permutation so outputs can be restored to input order.
    sorted_sentence_lengths, sort_indices = torch.sort(sentence_lengths, dim=0, descending=True)
    _, unsort_indices = torch.sort(sort_indices, dim=0)
    sorted_sentences = sentences.index_select(1, sort_indices)
    # Handle padding for RNN's.
    packed_sentences = nn.utils.rnn.pack_padded_sequence(
        sorted_sentences, sorted_sentence_lengths.data.clone().cpu().numpy())
    # [seq. length, batch size, 2 * num. layers * num. hidden]
    encoder_outputs = self.encoder(packed_sentences)[0]
    encoder_outputs = nn.utils.rnn.pad_packed_sequence(encoder_outputs)[0]
    # Restore the original (pre-sort) batch order.
    encoder_outputs = encoder_outputs.index_select(1, unsort_indices)
    # Rearrange to [batch size, hidden size, seq. length] for 1D pooling over time.
    encoder_outputs = encoder_outputs.transpose(0, 2).transpose(0, 1)
    if self.pooling_mode == "max":
        encoder_outputs = F.max_pool1d(encoder_outputs, kernel_size=encoder_outputs.size(2))
    elif self.pooling_mode == "avg":
        encoder_outputs = F.avg_pool1d(encoder_outputs, kernel_size=encoder_outputs.size(2))
    else:
        # Previously an unknown mode silently returned the un-pooled tensor.
        raise ValueError("Unsupported pooling mode: {}".format(self.pooling_mode))
    # squeeze(2) rather than squeeze(): a bare squeeze() would also drop the
    # batch dimension when batch size is 1, breaking the return contract.
    return encoder_outputs.squeeze(2)