def forward(self, sequence, lengths, h, c):
    """Run a padded batch through the LSTM and project every time step.

    The padded batch is packed so the recurrent layer skips padding
    positions, then unpacked back to a padded tensor before
    ``self.hidden2out`` is applied.

    Args:
        sequence: Padded input batch with the batch dimension first (required
            by ``batch_first=True`` below).
        lengths: True (unpadded) length of each sequence, as a list or 1-D
            integer tensor. It need not be sorted and may be on any device;
            a tensor is moved to the CPU because ``pack_padded_sequence``
            only accepts CPU lengths.
        h: Initial hidden state handed to ``self.lstm``.
        c: Initial cell state handed to ``self.lstm``.

    Returns:
        ``self.hidden2out`` applied to the padded recurrent output, batch
        first and in the same batch order as ``sequence``. The time dimension
        is padded only up to ``max(lengths)``, and positions past a
        sequence's length hold ``hidden2out`` applied to zero vectors (the
        ``pad_packed_sequence`` default padding value).

    Note:
        The final ``(h, c)`` returned by ``self.lstm`` is discarded, so state
        cannot be carried across calls through this method.
    """
    import torch  # local import: the file's top-level imports are not visible

    if isinstance(lengths, torch.Tensor):
        # pack_padded_sequence raises if `lengths` lives on a GPU.
        lengths = lengths.cpu()

    # enforce_sorted=False accepts batches in any order. PyTorch sorts
    # internally and unsorts on unpacking. Presumably self.lstm is an
    # nn.LSTM, which also permutes (h, c) to match (TODO confirm).
    packed = nn.utils.rnn.pack_padded_sequence(
        sequence, lengths, batch_first=True, enforce_sorted=False
    )
    packed_output, _final_state = self.lstm(packed, (h, c))
    output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
    return self.hidden2out(output)
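# 评论列表 ("Comment list") -- web-page navigation residue, not part of the code
# 文章目录 ("Table of contents") -- web-page navigation residue, not part of the code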