encoders.py source code

Project: keita  Author: iwasaki-kenta
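```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Method of the encoder nn.Module; it relies on self.encoder (a bidirectional RNN)
# and self.pooling_mode ("max" or "avg") being set in __init__.
def forward(self, x):
    """
    A bidirectional RNN encoder with support for global max/average pooling.

    :param x: A tuple of tensors: a padded sentence batch of shape
        [seq. length, batch size, embed. size] and the sentence lengths [batch size].
    :return: Global max/average pooled embeddings from the bidirectional RNN encoder,
        of shape [batch size, 2 * hidden size].
    """

    sentences, sentence_lengths = x

    # Sort sentences by descending length, as pack_padded_sequence expects by default.
    sorted_sentence_lengths, sort_indices = torch.sort(sentence_lengths, dim=0, descending=True)
    # Sorting the permutation itself yields its inverse, used later to restore the batch order.
    _, unsort_indices = torch.sort(sort_indices, dim=0)

    sorted_sentences = sentences.index_select(1, sort_indices)

    # Pack the padded batch so the RNN skips padding; the lengths must live on the CPU.
    packed_sentences = nn.utils.rnn.pack_padded_sequence(sorted_sentences, sorted_sentence_lengths.cpu())

    # [seq. length, batch size, 2 * hidden size]: forward and backward outputs concatenated.
    encoder_outputs = self.encoder(packed_sentences)[0]
    encoder_outputs = nn.utils.rnn.pad_packed_sequence(encoder_outputs)[0]

    # Unsort outputs back into the original batch order.
    encoder_outputs = encoder_outputs.index_select(1, unsort_indices)

    # Apply global max/average pooling 1D over the time axis:
    # [seq. length, batch size, features] -> [batch size, features, seq. length].
    # Note: padded time steps are zeros and are included in the pooling window.
    encoder_outputs = encoder_outputs.transpose(0, 2).transpose(0, 1)
    if self.pooling_mode == "max":
        encoder_outputs = F.max_pool1d(encoder_outputs, kernel_size=encoder_outputs.size(2))
    elif self.pooling_mode == "avg":
        encoder_outputs = F.avg_pool1d(encoder_outputs, kernel_size=encoder_outputs.size(2))
    else:
        raise ValueError(f"Unknown pooling mode: {self.pooling_mode!r}")

    # Drop only the pooled time axis, so a batch of size 1 keeps its batch dimension.
    encoder_outputs = encoder_outputs.squeeze(2)

    return encoder_outputs
```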
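One caveat with the forward pass above: `pad_packed_sequence` fills padded time steps with zeros, and the global pooling runs over them, so a shorter sentence's average is diluted and its max can be a padding zero. Below is a minimal sketch of a length-aware alternative, assuming PyTorch 1.1 or later (for `enforce_sorted=False`) and an `nn.LSTM` as the encoder. The class name `MaskedPoolingBiRNNEncoder` and its constructor parameters are hypothetical and not part of the keita project.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class MaskedPoolingBiRNNEncoder(nn.Module):
    """Hypothetical bidirectional LSTM encoder with length-aware global pooling."""

    def __init__(self, embed_size, hidden_size, pooling_mode="max"):
        super().__init__()
        if pooling_mode not in ("max", "avg"):
            raise ValueError(f"Unknown pooling mode: {pooling_mode!r}")
        self.encoder = nn.LSTM(embed_size, hidden_size, bidirectional=True)
        self.pooling_mode = pooling_mode

    def forward(self, x):
        sentences, sentence_lengths = x  # [seq, batch, embed], [batch]
        # enforce_sorted=False lets PyTorch sort and unsort the batch internally.
        packed = pack_padded_sequence(sentences, sentence_lengths.cpu(), enforce_sorted=False)
        # Padded back to the longest sentence, in the original batch order.
        outputs, _ = pad_packed_sequence(self.encoder(packed)[0])  # [seq, batch, 2 * hidden]

        lengths = sentence_lengths.to(outputs.device)
        # mask[t, b, 0] is True for real (non-padded) time steps.
        steps = torch.arange(outputs.size(0), device=outputs.device).unsqueeze(1)
        mask = (steps < lengths.unsqueeze(0)).unsqueeze(2)

        if self.pooling_mode == "max":
            # Padded steps become -inf so they can never be selected as the maximum.
            return outputs.masked_fill(~mask, float("-inf")).max(dim=0).values
        # Average only over each sentence's real time steps.
        summed = (outputs * mask.to(outputs.dtype)).sum(dim=0)
        return summed / lengths.unsqueeze(1).to(outputs.dtype)
```

Letting `pack_padded_sequence` sort internally removes the manual sort/unsort bookkeeping, while masking padded steps with `-inf` before the max and dividing by the true lengths for the average keeps padding out of the pooled embedding.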