def forward(self, word_input, last_hidden, encoder_outputs):
    # Note: we run this one decoder time step at a time
    # Get the embedding of the current input word batch (the last output words)
    batch_size = word_input.size(0)
    word_embedded = self.embedding(word_input).view(1, batch_size, -1)  # 1 x B x N
    word_embedded = self.dropout(word_embedded)
    # Calculate attention weights and apply them to the encoder outputs
    attn_weights = self.attn(last_hidden[-1], encoder_outputs)   # B x 1 x S
    context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # B x 1 x N
    context = context.transpose(0, 1)                            # 1 x B x N
    # Combine embedded input word and attended context, run through RNN
    rnn_input = torch.cat((word_embedded, context), 2)           # 1 x B x 2N
    output, hidden = self.gru(rnn_input, last_hidden)
    # Final output layer
    output = output.squeeze(0)    # B x N
    context = context.squeeze(0)  # B x N
    output = F.log_softmax(self.out(torch.cat((output, context), 1)), dim=1)
    # Return final output, hidden state, and attention weights (for visualization)
    return output, hidden, attn_weights
Source file: seq2seq_batched_10.py
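The forward above assumes that self.attn returns attention weights of shape B x 1 x S, so that the batched matrix multiply against encoder_outputs.transpose(0, 1) (B x S x N) yields a B x 1 x N context vector. The attention module itself is not part of this excerpt; the following is a minimal sketch under that assumption, using a simple dot-product score (the class name DotAttn and all tensor sizes are illustrative, not taken from the file):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DotAttn(nn.Module):
    """Hypothetical stand-in for the self.attn module used in forward() above."""
    def forward(self, hidden, encoder_outputs):
        # hidden: B x N (top layer of the previous decoder hidden state)
        # encoder_outputs: S x B x N
        energy = torch.sum(hidden * encoder_outputs, dim=2)   # S x B, dot-product scores
        return F.softmax(energy.t(), dim=1).unsqueeze(1)      # B x 1 x S

S, B, N = 7, 4, 16
encoder_outputs = torch.randn(S, B, N)  # S x B x N, as produced by the encoder
last_hidden = torch.randn(2, B, N)      # n_layers x B x N GRU hidden state

attn_weights = DotAttn()(last_hidden[-1], encoder_outputs)   # B x 1 x S
context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # B x 1 x N
print(attn_weights.shape, context.shape)  # torch.Size([4, 1, 7]) torch.Size([4, 1, 16])

Because the decoder runs one step at a time, forward is typically called once per output time step in a loop, feeding back either the ground-truth token (teacher forcing) or the model's own argmax prediction as the next word_input.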