sequence_labeling.py source code (Python)


Project: NeuroNLP2    Author: XuezheMax
def loss(self, input_word, input_char, target, mask=None, length=None, hx=None, leading_symbolic=0):
        # [batch, length, num_labels]
        output, mask, length = self.forward(input_word, input_char, mask=mask, length=length, hx=hx)
        # [batch, length, num_labels]
        output = self.dense_softmax(output)
        # preds = [batch, length]; take the argmax over the label dimension, skipping the first
        # `leading_symbolic` labels, then shift the indices back to the full label space
        _, preds = torch.max(output[:, :, leading_symbolic:], dim=2)
        preds += leading_symbolic

        output_size = output.size()
        # [batch * length, num_labels]
        output_size = (output_size[0] * output_size[1], output_size[2])
        output = output.view(output_size)

        if length is not None and target.size(1) != mask.size(1):
            # forward() may have truncated the batch to its true max length; trim target to match
            max_len = length.max()
            target = target[:, :max_len].contiguous()

        if mask is not None:
            # TODO for Pytorch 2.0.4, first take nllloss then mask (no need of broadcast for mask)
            # zero the log-probabilities at padded positions so they contribute nothing to the loss,
            # then divide by the number of real tokens
            return self.nll_loss(self.logsoftmax(output) * mask.contiguous().view(output_size[0], 1),
                                 target.view(-1)) / mask.sum(), \
                   (torch.eq(preds, target).type_as(mask) * mask).sum(), preds
        else:
            num = output_size[0] * output_size[1]
            return self.nll_loss(self.logsoftmax(output), target.view(-1)) / num, \
                   (torch.eq(preds, target).type_as(output)).sum(), preds
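
For context, the sketch below shows one way this loss method could be driven from a training loop. It is a minimal illustration, not code from the NeuroNLP2 project: the train_step helper, the batch dictionary keys, and the tensor shapes in the comments are assumptions made for the example; only the call to network.loss(...) and its three return values come from the method above.

def train_step(network, batch, optimizer, leading_symbolic=1):
    # `network` is assumed to be an instance of the sequence labeling model
    # defined in this file; the batch layout below is illustrative only.
    input_word = batch['word']    # [batch, length] word ids
    input_char = batch['char']    # [batch, length, char_length] character ids
    target = batch['label']       # [batch, length] gold label ids
    mask = batch['mask']          # [batch, length] 1.0 for real tokens, 0.0 for padding

    optimizer.zero_grad()
    loss, num_correct, preds = network.loss(input_word, input_char, target,
                                            mask=mask, leading_symbolic=leading_symbolic)
    loss.backward()
    optimizer.step()

    # token-level accuracy over the non-padded positions
    accuracy = num_correct.item() / mask.sum().item()
    return loss.item(), accuracy

Because loss returns the normalized negative log-likelihood, the count of correctly predicted (unmasked) tokens, and the predicted label indices together, per-epoch accuracy can be accumulated across batches without a separate prediction pass.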