seq2seq.py source code


Project: deep-text-corrector — Author: andabi
def forward(self, hidden, encoder_outputs):
        # hidden.size() = (B, H), encoder_outputs.size() = (B, S, H)
        batch_size, encoder_outputs_len, _ = encoder_outputs.size()

        # Create variable to store attention energies
        # attn_energies.size() = (B, S)
        attn_energies = Variable(torch.zeros((batch_size, encoder_outputs_len)))  # B x S
        if Config.use_cuda: attn_energies = attn_energies.cuda()

        # Calculate energies for each encoder output
        # attn_energies.size() = (B, S)
        for i in range(encoder_outputs_len):
            attn_energies[:, i] = self.score(hidden, encoder_outputs[:, i])

        # Normalize energies to weights in range 0 to 1 along the
        # sequence dimension (dim=1), so each row sums to 1
        return F.softmax(attn_energies, dim=1)
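The snippet does not show the module's `score` function, so here is a minimal NumPy sketch of the same loop-plus-softmax computation, assuming a simple dot-product score (a common choice; the actual project may use a different scoring function). The names `dot_score` and `attention_weights` are hypothetical, introduced only for this illustration.

```python
import numpy as np

def dot_score(hidden, encoder_output):
    # Hypothetical dot-product score, one value per batch element.
    # hidden: (B, H), encoder_output: (B, H) -> (B,)
    return np.sum(hidden * encoder_output, axis=1)

def attention_weights(hidden, encoder_outputs):
    # Mirrors the forward() above:
    # hidden: (B, H), encoder_outputs: (B, S, H) -> weights (B, S)
    B, S, _ = encoder_outputs.shape
    energies = np.zeros((B, S))
    for i in range(S):
        energies[:, i] = dot_score(hidden, encoder_outputs[:, i])
    # Softmax over the sequence dimension S (numerically stabilized)
    e = np.exp(energies - energies.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Usage: each row of the result is a distribution over the S positions
B, S, H = 2, 4, 8
rng = np.random.default_rng(0)
w = attention_weights(rng.standard_normal((B, H)),
                      rng.standard_normal((B, S, H)))
print(w.shape)        # (2, 4)
print(w.sum(axis=1))  # each row sums to 1
```

The per-position loop matches the original code; in practice the same energies can be computed in one batched matrix multiply.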