rte_model.py source code

python

Project: Recognizing-Textual-Entailment    Author: codedecde
import torch
import torch.nn.functional as F

def _attention_forward(self, Y, mask_Y, h, r_tm1=None, index=None):
        '''
        Computes the Attention Weights over Y using h (and r_tm1 if given)
        Returns an attention weighted representation of Y, and the alphas
        inputs:
            Y : T x batch x n_dim
            mask_Y : T x batch
            h : batch x n_dim
            r_tm1 : batch x n_dim
            index : int : The timestep
        params:
            W_y : n_dim x n_dim
            W_h : n_dim x n_dim
            W_r : n_dim x n_dim
            W_alpha : n_dim x 1
        outputs:
            r : batch x n_dim
            alpha : batch x T
        '''
        Y = Y.transpose(1, 0)  # batch x T x n_dim
        mask_Y = mask_Y.transpose(1, 0)  # batch x T

        Wy = torch.bmm(Y, self.W_y.unsqueeze(0).expand(Y.size(0), *self.W_y.size()))  # batch x T x n_dim
        Wh = torch.mm(h, self.W_h)  # batch x n_dim
        if r_tm1 is not None:
            # Project the previous attention-weighted representation
            W_r_tm1 = torch.mm(r_tm1, self.W_r)  # batch x n_dim
            if hasattr(self, 'batch_norm_r_r'):
                W_r_tm1 = self.batch_norm_r_r(W_r_tm1, index)
            if hasattr(self, 'batch_norm_h_r'):
                Wh = self.batch_norm_h_r(Wh, index)
            Wh = Wh + W_r_tm1
        M = torch.tanh(Wy + Wh.unsqueeze(1).expand(Wh.size(0), Y.size(1), Wh.size(1)))  # batch x T x n_dim
        alpha = torch.bmm(M, self.W_alpha.unsqueeze(0).expand(Y.size(0), *self.W_alpha.size())).squeeze(-1)  # batch x T
        # Push padding positions to a large negative value so the softmax
        # assigns them (near-)zero probability mass
        alpha = alpha + (-1000.0 * (1. - mask_Y))  # batch x T
        alpha = F.softmax(alpha, dim=-1)  # normalize over the T timesteps
        if r_tm1 is not None:
            r = torch.bmm(alpha.unsqueeze(1), Y).squeeze(1) + torch.tanh(torch.mm(r_tm1, self.W_t))  # batch x n_dim
        else:
            r = torch.bmm(alpha.unsqueeze(1), Y).squeeze(1)  # batch x n_dim
        return r, alpha
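
For orientation, below is a minimal sketch of how this attention step could be driven standalone. The AttentionStub class, its randomly initialized weights, and the concrete sizes (T=5, batch=3, n_dim=8) are illustrative assumptions; in the repository these parameters live on the full model (an nn.Module), not a stub.

import torch

class AttentionStub:
    # Hypothetical parameter holder; shapes follow the docstring above.
    def __init__(self, n_dim):
        self.W_y = torch.randn(n_dim, n_dim)
        self.W_h = torch.randn(n_dim, n_dim)
        self.W_r = torch.randn(n_dim, n_dim)
        self.W_t = torch.randn(n_dim, n_dim)
        self.W_alpha = torch.randn(n_dim, 1)

AttentionStub._attention_forward = _attention_forward  # bind the function above

T, batch, n_dim = 5, 3, 8
stub = AttentionStub(n_dim)
Y = torch.randn(T, batch, n_dim)  # premise encodings, T x batch x n_dim
mask_Y = torch.ones(T, batch)     # 1.0 for real tokens, 0.0 for padding
h = torch.randn(batch, n_dim)     # current hypothesis hidden state

r, alpha = stub._attention_forward(Y, mask_Y, h)
print(r.shape, alpha.shape)  # torch.Size([3, 8]) torch.Size([3, 5])

# Recurrent variant: feed the previous representation back in
r2, alpha2 = stub._attention_forward(Y, mask_Y, h, r_tm1=r, index=0)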