def infer(self, input_variable):
    """Greedily decode an output sequence for one encoded input.

    Encodes ``input_variable`` once, then runs the decoder step by step,
    feeding each step's highest-scoring token back in as the next input
    (greedy search). Decoding stops when ``EOS_token`` is emitted or after
    ``self.max_length`` steps.

    Args:
        input_variable: tensor of input token indices for a single example
            (batch size 1; exact shape depends on the encoder — TODO confirm
            against ``self.encoder``'s expected input).

    Returns:
        Tensor of the per-step decoder outputs, stacked along a new leading
        time dimension via ``torch.cat`` of the ``unsqueeze(0)``'d outputs.
    """
    # Encode the whole input sequence once; the encoder's final hidden
    # state seeds the decoder.
    encoder_hidden = self.encoder.init_hidden()
    encoder_outputs, encoder_hidden = self.encoder(input_variable, encoder_hidden)

    # Decoder starts from the start-of-sequence token and a zero context
    # vector.
    decoder_input = Variable(torch.LongTensor([[SOS_token]]))
    decoder_context = Variable(torch.zeros(1, self.decoder.hidden_size))
    decoder_hidden = encoder_hidden
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        decoder_context = decoder_context.cuda()

    decoder_outputs = []
    for _ in range(self.max_length):
        decoder_output, decoder_context, decoder_hidden, decoder_attention = self.decoder(
            decoder_input, decoder_context, decoder_hidden, encoder_outputs)
        decoder_outputs.append(decoder_output.unsqueeze(0))

        # Greedy step: the highest-probability token becomes the next input.
        topv, topi = decoder_output.data.topk(1)
        ni = topi[0][0]
        decoder_input = Variable(torch.LongTensor([[ni]]))  # Chosen word is next input
        if USE_CUDA:
            decoder_input = decoder_input.cuda()
        if ni == EOS_token:
            break

    return torch.cat(decoder_outputs, 0)