def step(self, input_variable, target_variable, max_length):
    teacher_forcing_ratio = 0.1
    clip = 5.0
    loss = 0  # Accumulated for each target word

    # Zero gradients of both optimizers
    self.encoder_optimizer.zero_grad()
    self.decoder_optimizer.zero_grad()

    input_length = input_variable.size()[0]
    target_length = target_variable.size()[0]

    # Run the input sequence through the encoder
    encoder_hidden = self.encoder.init_hidden()
    encoder_outputs, encoder_hidden = self.encoder(input_variable, encoder_hidden)

    # Prepare decoder inputs: start with the SOS token and a zero context vector
    decoder_input = Variable(torch.LongTensor([[SOS_token]]))
    decoder_context = Variable(torch.zeros(1, self.decoder.hidden_size))
    decoder_hidden = encoder_hidden  # Use last hidden state from encoder to start decoder
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        decoder_context = decoder_context.cuda()

    decoder_outputs = []

    # Randomly decide whether to feed the ground truth (teacher forcing)
    # or the model's own prediction as the next decoder input
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: next input is the current target word
        for di in range(target_length):
            decoder_output, decoder_context, decoder_hidden, decoder_attention = self.decoder(
                decoder_input, decoder_context, decoder_hidden, encoder_outputs)
            loss += self.criterion(decoder_output, target_variable[di])
            decoder_input = target_variable[di]
            decoder_outputs.append(decoder_output.unsqueeze(0))
    else:
        # No teacher forcing: next input is the decoder's own prediction
        for di in range(target_length):
            decoder_output, decoder_context, decoder_hidden, decoder_attention = self.decoder(
                decoder_input, decoder_context, decoder_hidden, encoder_outputs)
            loss += self.criterion(decoder_output, target_variable[di])
            decoder_outputs.append(decoder_output.unsqueeze(0))

            # Pick the most likely word index and feed it back in
            topv, topi = decoder_output.data.topk(1)
            ni = topi[0][0]
            decoder_input = Variable(torch.LongTensor([[ni]]))
            if USE_CUDA: decoder_input = decoder_input.cuda()

            # Stop early once the end-of-sentence token is produced
            if ni == EOS_token: break

    # Backpropagate, clip gradients to avoid explosion, then update both networks
    loss.backward()
    torch.nn.utils.clip_grad_norm(self.encoder.parameters(), clip)
    torch.nn.utils.clip_grad_norm(self.decoder.parameters(), clip)
    self.encoder_optimizer.step()
    self.decoder_optimizer.step()

    decoder_outputs = torch.cat(decoder_outputs, 0)
    return loss.data[0] / target_length, decoder_outputs
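As a rough illustration of how this step method might be driven, here is a minimal training-loop sketch. The trainer object (assumed to hold the encoder, decoder, optimizers, and criterion used above), the random_training_pair() helper, and MAX_LENGTH are hypothetical names introduced for this example; only step() itself comes from the code above.

n_epochs = 50000
print_every = 1000
print_loss_total = 0

for epoch in range(1, n_epochs + 1):
    # random_training_pair() is a hypothetical helper returning an
    # (input_variable, target_variable) pair already wrapped as Variables
    input_variable, target_variable = random_training_pair()

    # One optimization step over a single sentence pair
    loss, decoder_outputs = trainer.step(input_variable, target_variable, MAX_LENGTH)
    print_loss_total += loss

    # Periodically report the average per-word loss
    if epoch % print_every == 0:
        print('epoch %d, average loss %.4f' % (epoch, print_loss_total / print_every))
        print_loss_total = 0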