seq2seq_batched_10.py source file

python

Project: Seq2Seq-on-Word-Sense-Disambiguition Author: lbwbowenLi
import torch
import torch.nn.functional as F

# Method of the attention module; the enclosing class is not shown in this excerpt.
def forward(self, hidden, encoder_outputs):
    # encoder_outputs is indexed sequence-first: dimension 0 is S, dimension 1 is B
    max_len = encoder_outputs.size(0)
    this_batch_size = encoder_outputs.size(1)

    # Tensor to store attention energies, allocated on the same device as encoder_outputs
    attn_energies = torch.zeros(this_batch_size, max_len, device=encoder_outputs.device)  # B x S

    # For each sequence in the batch
    for b in range(this_batch_size):
        # Calculate energy for each encoder output
        for i in range(max_len):
            attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))

    # Normalize energies over the sequence dimension so each row sums to 1,
    # then insert a singleton dimension: B x S -> B x 1 x S
    return F.softmax(attn_energies, dim=1).unsqueeze(1)
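
The double loop above calls the score function B * S times from Python, which tends to be slow for large batches. The sketch below shows how the same energies could be computed in one batched operation. It is not the author's code: `self.score` is not shown in this excerpt, so a plain dot-product score is assumed here, `hidden` is assumed to have shape 1 x B x H and `encoder_outputs` shape S x B x H, and the names `dot_attention_loop` and `dot_attention_batched` are made up for illustration.

```python
import torch
import torch.nn.functional as F


def dot_attention_loop(hidden, encoder_outputs):
    """Reference version: the same double loop, with an assumed dot-product score."""
    max_len, batch_size, _ = encoder_outputs.shape
    energies = torch.zeros(batch_size, max_len, device=encoder_outputs.device)
    for b in range(batch_size):
        for i in range(max_len):
            # Dot product between the decoder state and one encoder output (both length H)
            energies[b, i] = torch.dot(hidden[0, b], encoder_outputs[i, b])
    return F.softmax(energies, dim=1).unsqueeze(1)  # B x 1 x S


def dot_attention_batched(hidden, encoder_outputs):
    """Vectorized version: one einsum replaces the B * S Python-level score calls."""
    # hidden[0]: B x H, encoder_outputs: S x B x H -> energies: B x S
    energies = torch.einsum("bh,sbh->bs", hidden[0], encoder_outputs)
    return F.softmax(energies, dim=1).unsqueeze(1)  # B x 1 x S


torch.manual_seed(0)
hidden = torch.randn(1, 3, 8)           # 1 x B x H
encoder_outputs = torch.randn(5, 3, 8)  # S x B x H
weights = dot_attention_batched(hidden, encoder_outputs)
print(weights.shape)  # torch.Size([3, 1, 5])
```

Each row of `weights` sums to 1 over the sequence dimension, and the batched result should match the loop version up to floating-point error. A learned score (for example one that applies a linear layer before the dot product) would need that layer applied to `encoder_outputs` before the einsum.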