seq2seq.py source code (Python)

Project: deep-text-corrector · Author: andabi
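# One decoding step of an attention decoder: given the previous token,
# the previous context vector, and the previous hidden state, attend over
# the encoder outputs and predict the next word.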
def forward(self, input, last_context, last_hidden, encoder_outputs):
        # input.size() = (B, 1), last_context.size() = (B, H), last_hidden.size() = (L, B, H), encoder_outputs.size() = (B, S, H)
        # word_embedded.size() = (B, 1, H)
        # print(input.size())
        word_embedded = self.embedding(input)

        # rnn_input.size() = (B, 1, 2H), rnn_output.size() = (B, 1, H)
        # print(word_embedded.size(), last_context.unsqueeze(1).size())
        rnn_input = torch.cat((word_embedded, last_context.unsqueeze(1)), -1)
        rnn_output, hidden = self.gru(rnn_input, last_hidden)
        rnn_output = rnn_output.squeeze(1)  # B x S=1 x H -> B x H

        # attn_weights.size() = (B, S)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.unsqueeze(1).bmm(encoder_outputs).squeeze(1)  # B x H

        # TODO tanh?
        # Final output layer (next word prediction) using the RNN hidden state and context vector
        output = self.out(torch.cat((rnn_output, context), -1))  # B x V

        # Return final output, hidden state, and attention weights (for visualization)
        # output.size() = (B, V)
        return output, context, hidden, attn_weights
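
For readers who want to exercise this step in isolation, here is a minimal, self-contained sketch of the surrounding decoder class together with a dot-product attention module that matches the shape annotations above. The names DotAttn and AttnDecoder, the hyperparameters, and the dot-product scoring function are illustrative assumptions; the original repository may wire these pieces up differently.

# Minimal sketch (assumptions, not the original repo's code): a dot-product
# attention scorer and a decoder class that the forward step above fits into,
# plus a shape check on dummy tensors.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DotAttn(nn.Module):
    # Scores each encoder state against the current decoder state with a
    # dot product, then normalizes over the source dimension S.
    def forward(self, rnn_output, encoder_outputs):
        # rnn_output: (B, H), encoder_outputs: (B, S, H)
        scores = encoder_outputs.bmm(rnn_output.unsqueeze(-1)).squeeze(-1)  # (B, S)
        return F.softmax(scores, dim=-1)

class AttnDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # The GRU consumes [embedding; context], hence 2H input features.
        self.gru = nn.GRU(2 * hidden_size, hidden_size,
                          num_layers=num_layers, batch_first=True)
        self.attn = DotAttn()
        self.out = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, input, last_context, last_hidden, encoder_outputs):
        word_embedded = self.embedding(input)                          # (B, 1, H)
        rnn_input = torch.cat((word_embedded,
                               last_context.unsqueeze(1)), -1)         # (B, 1, 2H)
        rnn_output, hidden = self.gru(rnn_input, last_hidden)
        rnn_output = rnn_output.squeeze(1)                             # (B, H)
        attn_weights = self.attn(rnn_output, encoder_outputs)          # (B, S)
        context = attn_weights.unsqueeze(1).bmm(encoder_outputs).squeeze(1)  # (B, H)
        output = self.out(torch.cat((rnn_output, context), -1))        # (B, V)
        return output, context, hidden, attn_weights

# Shape check on dummy tensors.
B, S, H, V, L = 4, 7, 16, 100, 1
decoder = AttnDecoder(V, H, num_layers=L)
prev_token = torch.zeros(B, 1, dtype=torch.long)   # previous output token ids
prev_context = torch.zeros(B, H)                   # context from the previous step
prev_hidden = torch.zeros(L, B, H)                 # GRU hidden state, (L, B, H)
encoder_outputs = torch.randn(B, S, H)
output, context, hidden, weights = decoder(prev_token, prev_context,
                                           prev_hidden, encoder_outputs)
assert output.size() == (B, V) and weights.size() == (B, S)

Feeding output (after an argmax or sampling step), context, and hidden back into the next call repeats this step across the target sequence, which is how such a decoder is typically driven during inference.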