lstm_attention.py source file

Language: Python

Project: pytorch-seq2seq    Author: rowanz
def forward(self, dec_state, context, mask=None):
        """
        :param dec_state:  batch x dec_dim
        :param context: batch x T x enc_dim
        :return: Weighted context, batch x enc_dim
                 Alpha weights (viz), batch x T
        """
        batch, source_l, enc_dim = context.size()
        assert enc_dim == self.enc_dim

        # W*s over the entire batch (batch, attn_dim)
        dec_contrib = self.decoder_in(dec_state)

        # W*h over the entire length & batch (batch, source_l, attn_dim)
        enc_contribs = self.encoder_in(
            context.view(-1, self.enc_dim)).view(batch, source_l, self.attn_dim)

        # tanh( Wh*hj + Ws s_{i-1} )     (batch, source_l, dim)
        pre_attn = F.tanh(enc_contribs + dec_contrib.unsqueeze(1).expand_as(enc_contribs))

        # v^T*pre_attn for all batches/lengths (batch, source_l)
        energy = self.att_linear(pre_attn.view(-1, self.attn_dim)).view(batch, source_l)

        # Apply the mask. (Might be a better way to do this)
        if mask is not None:
            shift = energy.max(1)[0]
            energy_exp = (energy - shift.expand_as(energy)).exp() * mask
            alpha = torch.div(energy_exp, energy_exp.sum(1).expand_as(energy_exp))
        else:
            alpha = F.softmax(energy)

        weighted_context = torch.bmm(alpha.unsqueeze(1), context).squeeze(1)  # (batch, dim)
        return weighted_context, alpha
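
The file excerpt shows only forward(); the layers it relies on (self.decoder_in, self.encoder_in, self.att_linear) are defined elsewhere in the class, and a few calls (F.tanh, softmax without a dim argument, expand_as on the results of max/sum) assume an older PyTorch API. Below is a minimal, self-contained sketch of how the whole module could look, assuming those layers are plain nn.Linear projections; the class name, constructor arguments, and bias choices are illustrative, not the project's actual code.

# Sketch only: __init__ is an assumption, the forward pass mirrors the logic above
# but uses the current PyTorch API (torch.tanh, softmax with an explicit dim,
# broadcasting instead of expand_as, masked_fill instead of manual renormalisation).
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttention(nn.Module):
    """Bahdanau-style additive attention: score(s, h_j) = v^T tanh(W_s s + W_h h_j)."""

    def __init__(self, dec_dim, enc_dim, attn_dim):
        super().__init__()
        self.enc_dim = enc_dim
        self.attn_dim = attn_dim
        self.decoder_in = nn.Linear(dec_dim, attn_dim, bias=False)   # W_s
        self.encoder_in = nn.Linear(enc_dim, attn_dim, bias=False)   # W_h
        self.att_linear = nn.Linear(attn_dim, 1, bias=False)         # v

    def forward(self, dec_state, context, mask=None):
        # dec_state: (batch, dec_dim), context: (batch, T, enc_dim), mask: (batch, T) of 0/1
        batch, source_l, enc_dim = context.size()
        assert enc_dim == self.enc_dim

        dec_contrib = self.decoder_in(dec_state)             # (batch, attn_dim)
        enc_contribs = self.encoder_in(context)              # (batch, T, attn_dim)

        # Additive combination followed by tanh; broadcasting adds the decoder
        # contribution to every source position.
        pre_attn = torch.tanh(enc_contribs + dec_contrib.unsqueeze(1))

        energy = self.att_linear(pre_attn).squeeze(-1)       # (batch, T)

        # Mask out padded source positions before the softmax.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-inf'))
        alpha = F.softmax(energy, dim=1)                     # (batch, T)

        # Weighted sum of encoder states:
        # (batch, 1, T) x (batch, T, enc_dim) -> (batch, enc_dim)
        weighted_context = torch.bmm(alpha.unsqueeze(1), context).squeeze(1)
        return weighted_context, alpha

Masking with -inf before the softmax is the usual modern idiom and is numerically equivalent to what the original forward() does by hand: subtract the row maximum for stability, exponentiate, zero out masked terms, and divide by the row sum. For example (still under the same assumptions), AdditiveAttention(dec_dim=512, enc_dim=256, attn_dim=128) applied to a (batch, 512) decoder state and a (batch, T, 256) context returns a (batch, 256) weighted context and a (batch, T) weight vector.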