def forward(self, input, last_context, last_hidden, encoder_outputs):
    # input.size() = (B, 1), last_context.size() = (B, H), last_hidden.size() = (L, B, H), encoder_outputs.size() = (B, S, H)
    # word_embedded.size() = (B, 1, H)
    # print(input.size())
    word_embedded = self.embedding(input)
    # rnn_input.size() = (B, 1, 2H), rnn_output.size() = (B, 1, H)
    # print(word_embedded.size(), last_context.unsqueeze(1).size())
    rnn_input = torch.cat((word_embedded, last_context.unsqueeze(1)), -1)
    rnn_output, hidden = self.gru(rnn_input, last_hidden)
    rnn_output = rnn_output.squeeze(1)  # B x S=1 x H -> B x H
    # attn_weights.size() = (B, S)
    attn_weights = self.attn(rnn_output, encoder_outputs)
    context = attn_weights.unsqueeze(1).bmm(encoder_outputs).squeeze(1)  # B x H
    # TODO: optionally apply tanh to [rnn_output; context] before the output layer (Luong-style attentional hidden state)?
    # Final output layer (next-word prediction) using the RNN hidden state and context vector
    output = self.out(torch.cat((rnn_output, context), -1))  # B x V
    # Return final output, context vector, hidden state, and attention weights (for visualization)
    # output.size() = (B, V)
    return output, context, hidden, attn_weights
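
# Sketch (my assumption, not from the original post): a constructor and attention module that
# would make the forward() above runnable; forward() would then be a method of the decoder class.
# The layer names embedding/gru/attn/out match the calls above; the class names, the Attn scoring
# scheme (plain dot product), and the hyperparameter names are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attn(nn.Module):
    # Dot-product attention: scores the (B, H) decoder state against the (B, S, H) encoder outputs.
    def __init__(self, hidden_size):
        super(Attn, self).__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden, encoder_outputs):
        # (B, S, H) bmm (B, H, 1) -> (B, S, 1) -> (B, S); softmax over the source dimension
        scores = encoder_outputs.bmm(hidden.unsqueeze(2)).squeeze(2)
        return F.softmax(scores, dim=1)

class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, n_layers=1):
        super(LuongAttnDecoderRNN, self).__init__()
        # (B, 1) token ids -> (B, 1, H) embeddings
        self.embedding = nn.Embedding(output_size, hidden_size)
        # GRU input is [embedding; last_context] = (B, 1, 2H); hidden stays (L, B, H)
        self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, batch_first=True)
        # attention weights over the source positions -> (B, S)
        self.attn = Attn(hidden_size)
        # [rnn_output; context] = (B, 2H) -> vocabulary logits (B, V)
        self.out = nn.Linear(hidden_size * 2, output_size)

# Hypothetical single decoding step (shapes only):
#   decoder = LuongAttnDecoderRNN(hidden_size=256, output_size=vocab_size)
#   output, context, hidden, attn_weights = decoder(prev_token.view(B, 1), last_context, last_hidden, encoder_outputs)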