crf.py source code

python

Project: NeuroNLP2 Author: XuezheMax
import numpy as np
import torch
from torch.autograd import Variable


def forward(self, input_h, input_c, mask=None):
    '''
    Compute an energy (score) for every (head, child) word pair under each label.

    Args:
        input_h: Tensor
            the head input tensor with shape = [batch, length, input_size]
        input_c: Tensor
            the child input tensor with shape = [batch, length, input_size]
        mask: Tensor or None
            the mask tensor with shape = [batch, length]

    Returns: Tensor
        the energy tensor with shape = [batch, num_label, length, length]

    '''
    length = input_h.size(1)
    # [batch, num_labels, length, length]
    output = self.attention(input_h, input_c, mask_d=mask, mask_e=mask)
    # Set diagonal elements to -inf so no word is scored as its own head.
    # torch.diag builds a [length, length] matrix (-inf on the diagonal, 0 elsewhere)
    # that broadcasts over the batch and label dimensions.
    output = output + Variable(torch.diag(output.data.new(length).fill_(-np.inf)))
    return output
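The diagonal trick can be checked without PyTorch or the project's `self.attention` module. The sketch below re-expresses it in NumPy (a substitution for portability, not the project's code); `mask_self_arcs` and `softmax_last` are hypothetical helper names, and random scores stand in for the attention output. Because exp(-inf) = 0, a softmax over the energies gives the masked self-arcs exactly zero weight.

```python
import numpy as np


def mask_self_arcs(energy: np.ndarray) -> np.ndarray:
    """Return energy + D, where D is -inf on the diagonal of the last two axes.

    energy has shape [batch, num_labels, length, length]; the [length, length]
    matrix D broadcasts over the batch and label axes, as in the PyTorch code.
    """
    length = energy.shape[-1]
    # -inf on the diagonal, 0.0 everywhere else (np.diag of a 1-D array)
    diag = np.diag(np.full(length, -np.inf))
    return energy + diag


def softmax_last(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis; exp(-inf) == 0."""
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


# Hypothetical stand-in for self.attention(...): random scores with
# batch=2, num_labels=3, length=4.
rng = np.random.default_rng(0)
energy = rng.standard_normal((2, 3, 4, 4))
masked = mask_self_arcs(energy)
probs = softmax_last(masked)

print(np.diagonal(probs, axis1=-2, axis2=-1).max())  # 0.0: no mass on self-arcs
```

This relies on length being at least 2: with a single word, every entry of a row would be -inf and the softmax would return NaN.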