attention.py source code

python

Project: NeuroNLP2  Author: XuezheMax
def forward(self, input_d, input_e, mask_d=None, mask_e=None):
        '''

        Args:
            input_d: Tensor
                the decoder input tensor with shape = [batch, length_decoder, input_size]
            input_e: Tensor
                the encoder input tensor with shape = [batch, length_encoder, input_size]
            mask_d: Tensor or None
                the mask tensor for decoder with shape = [batch, length_decoder] (accepted but not used in this method)
            mask_e: Tensor or None
                the mask tensor for encoder with shape = [batch, length_encoder] (accepted but not used in this method)

        Returns: Tensor
            the energy tensor with shape = [batch, num_labels, length_decoder, length_encoder]

        '''
        assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are required to be equal.'
        batch, length_decoder, _ = input_d.size()
        _, length_encoder, _ = input_e.size()

        # compute decoder part: [batch, length_decoder, input_size_decoder] * [input_size_decoder, hidden_size]
        # the output shape is [batch, length_decoder, hidden_size]
        # then --> [batch, 1, length_decoder, hidden_size]
        out_d = torch.matmul(input_d, self.W_d).unsqueeze(1)
        # compute encoder part: [batch, length_encoder, input_size_encoder] * [input_size_encoder, hidden_size]
        # the output shape is [batch, length_encoder, hidden_size]
        # then --> [batch, length_encoder, 1, hidden_size]
        out_e = torch.matmul(input_e, self.W_e).unsqueeze(2)

        # add them together [batch, length_encoder, length_decoder, hidden_size]
        out = torch.tanh(out_d + out_e + self.b)

        # product with v
        # [batch, length_encoder, length_decoder, hidden_size] * [hidden_size, num_labels]
        # [batch, length_encoder, length_decoder, num_labels]
        # then --> [batch, num_labels, length_decoder, length_encoder]
        return torch.matmul(out, self.v).transpose(1, 3)
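
The shape bookkeeping in the comments above can be checked with a small standalone sketch. It is not part of NeuroNLP2: it uses NumPy in place of PyTorch, and the function name `concat_attention_energy`, the explicit parameter arguments (standing in for `self.W_d`, `self.W_e`, `self.b`, `self.v`), and all sizes are assumptions chosen for illustration. NumPy broadcasting follows the same rules as the PyTorch calls in the method, so the resulting shapes should match.

```python
import numpy as np


def concat_attention_energy(input_d, input_e, W_d, W_e, b, v):
    """Additive (concat) attention energies, mirroring the forward() above.

    Assumed shapes:
        input_d [batch, len_d, in_d],  input_e [batch, len_e, in_e],
        W_d [in_d, hidden],  W_e [in_e, hidden],  b [hidden],
        v [hidden, num_labels].
    """
    # Project the decoder states and add a broadcast axis for encoder positions.
    out_d = np.matmul(input_d, W_d)[:, None, :, :]  # [batch, 1, len_d, hidden]
    # Project the encoder states and add a broadcast axis for decoder positions.
    out_e = np.matmul(input_e, W_e)[:, :, None, :]  # [batch, len_e, 1, hidden]
    # Broadcasting pairs every encoder position with every decoder position.
    out = np.tanh(out_d + out_e + b)                # [batch, len_e, len_d, hidden]
    energy = np.matmul(out, v)                      # [batch, len_e, len_d, num_labels]
    # Same axis swap as torch's .transpose(1, 3).
    return energy.transpose(0, 3, 2, 1)             # [batch, num_labels, len_d, len_e]


# Hypothetical sizes, chosen to be pairwise distinct so axis mix-ups are visible.
rng = np.random.default_rng(0)
batch, len_d, len_e, in_d, in_e, hidden, num_labels = 2, 3, 5, 4, 6, 7, 8
input_d = rng.standard_normal((batch, len_d, in_d))
input_e = rng.standard_normal((batch, len_e, in_e))
W_d = rng.standard_normal((in_d, hidden))
W_e = rng.standard_normal((in_e, hidden))
b = rng.standard_normal(hidden)
v = rng.standard_normal((hidden, num_labels))

energy = concat_attention_energy(input_d, input_e, W_d, W_e, b, v)
print(energy.shape)  # (2, 8, 3, 5)
```

Entry `energy[n, :, i, j]` holds the label scores for decoder position `i` and encoder position `j` in batch element `n`, which is why the result is indexed decoder-first even though the intermediate tensor is encoder-first.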