rnn.py source code

python

Project: jack  Author: uclmr
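import torch
from torch import nn


def forward(self, inputs, lengths=None, start_state=None):
    # Method of a bidirectional-LSTM module. self._bilstm (presumably built with
    # batch_first=True), self._size, self._start_state_given and the learned
    # start vectors are defined elsewhere in the class (not shown here).
    if not self._start_state_given:
        # Broadcast the learned initial hidden/cell vectors over the batch.
        # The leading 2 is num_layers * num_directions (one layer, two directions).
        batch_size = inputs.size(0)
        start_hidden = self._lstm_start_hidden.unsqueeze(1).expand(2, batch_size, self._size).contiguous()
        start_cell = self._lstm_start_state.unsqueeze(1).expand(2, batch_size, self._size).contiguous()
        start_state = (start_hidden, start_cell)

    if lengths is not None:
        # pack_padded_sequence (default enforce_sorted=True) expects the batch
        # sorted by decreasing length, so reorder inputs accordingly.
        new_lengths, indices = torch.sort(lengths, dim=0, descending=True)
        inputs = torch.index_select(inputs, 0, indices)
        if self._start_state_given:
            # A caller-supplied state is per-example, so reorder it too (batch is dim 1).
            start_state = (torch.index_select(start_state[0], 1, indices),
                           torch.index_select(start_state[1], 1, indices))
        # Lengths must be a CPU tensor or a list of ints; the old `l.data[0]`
        # idiom raises an error on 0-dim tensors in current PyTorch.
        inputs = nn.utils.rnn.pack_padded_sequence(inputs, new_lengths.cpu(), batch_first=True)

    # Run the BiLSTM on either the packed or the plain padded batch.
    output, (h_n, c_n) = self._bilstm(inputs, start_state)

    if lengths is not None:
        # Unpack to a padded batch-first tensor, then undo the sort.
        output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
        _, back_indices = torch.sort(indices, dim=0)
        output = torch.index_select(output, 0, back_indices)
        h_n = torch.index_select(h_n, 1, back_indices)
        c_n = torch.index_select(c_n, 1, back_indices)

    return output, (h_n, c_n)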
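The sort-then-restore bookkeeping in `forward()` relies on one fact: the indices returned by sorting a permutation (`_, back_indices = torch.sort(indices)`) form its inverse. A minimal pure-Python sketch of that step follows; the helper names `sort_by_length`, `inverse_permutation` and `index_select` are hypothetical stand-ins for `torch.sort` and `torch.index_select`, not code from the jack project.

```python
def sort_by_length(lengths):
    """Return (sorted_lengths, indices) in descending order of length, where
    indices[i] is the original position of the i-th sorted entry."""
    indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    return [lengths[i] for i in indices], indices


def inverse_permutation(indices):
    """Argsort of a permutation, i.e. its inverse (like torch.sort(indices)[1])."""
    return sorted(range(len(indices)), key=lambda i: indices[i])


def index_select(items, indices):
    """Pick items in the given order, like torch.index_select along one dim."""
    return [items[i] for i in indices]


lengths = [2, 5, 3, 1]
sorted_lengths, indices = sort_by_length(lengths)   # [5, 3, 2, 1], [1, 2, 0, 3]
back_indices = inverse_permutation(indices)         # [2, 0, 1, 3]

batch = ["a", "b", "c", "d"]
sorted_batch = index_select(batch, indices)         # ["b", "c", "a", "d"]
restored = index_select(sorted_batch, back_indices) # ["a", "b", "c", "d"]
```

Applying `indices` and then `back_indices` returns the batch to its original order, which is why `output`, `h_n` and `c_n` line up with the caller's inputs again.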
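The reordering exists because `pack_padded_sequence`, with its default `enforce_sorted=True`, requires the batch sorted by decreasing length: a packed batch stores elements time step by time step, and `batch_sizes[t]` counts the sequences still active at step t, which are then exactly the first ones in the batch. The sketch below models that layout in pure Python with hypothetical functions `pack_sorted` and `unpack`; it is a simplified model of `PackedSequence.data` and `PackedSequence.batch_sizes`, not PyTorch's implementation.

```python
def pack_sorted(sequences):
    """Lay out a batch already sorted by descending length time-major:
    returns (data, batch_sizes), where batch_sizes[t] is the number of
    sequences that still have an element at step t."""
    lengths = [len(s) for s in sequences]
    if lengths != sorted(lengths, reverse=True):
        raise ValueError("sequences must be sorted by descending length")
    max_len = lengths[0] if lengths else 0
    data, batch_sizes = [], []
    for t in range(max_len):
        # Because of the sort, the active sequences are the first len(active) ones.
        active = [s[t] for s in sequences if len(s) > t]
        data.extend(active)
        batch_sizes.append(len(active))
    return data, batch_sizes


def unpack(data, batch_sizes, pad=0):
    """Inverse of pack_sorted: rebuild a padded batch-first list of lists."""
    batch = batch_sizes[0] if batch_sizes else 0
    padded = [[pad] * len(batch_sizes) for _ in range(batch)]
    offset = 0
    for t, n in enumerate(batch_sizes):
        for b in range(n):
            padded[b][t] = data[offset + b]
        offset += n
    return padded


data, batch_sizes = pack_sorted([[1, 2, 3], [4, 5], [6]])
padded = unpack(data, batch_sizes)
```

For `[[1, 2, 3], [4, 5], [6]]` this gives `data == [1, 4, 6, 2, 5, 3]`, `batch_sizes == [3, 2, 1]`, and `unpack` returns `[[1, 2, 3], [4, 5, 0], [6, 0, 0]]`. Since PyTorch 1.1, passing `enforce_sorted=False` to `pack_padded_sequence` performs the sort internally, and `pad_packed_sequence` and `nn.LSTM` return outputs and final states in the original batch order, which would make the manual sort in `forward()` unnecessary.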