def loss(self, input_word, input_char, target, mask=None, length=None, hx=None, leading_symbolic=0):
    # encoder output: [batch, length, hidden_size]
    output, mask, length = self.forward(input_word, input_char, mask=mask, length=length, hx=hx)
    # project to label space: [batch, length, num_labels]
    output = self.dense_softmax(output)
    # preds: [batch, length]; exclude the leading symbolic labels (e.g. PAD at the
    # front of the label alphabet) from the argmax, then shift the indices back
    _, preds = torch.max(output[:, :, leading_symbolic:], dim=2)
    preds += leading_symbolic
    output_size = output.size()
    # flatten to [batch * length, num_labels]
    output_size = (output_size[0] * output_size[1], output_size[2])
    output = output.view(output_size)
    # forward may have truncated the batch to its true max length,
    # so truncate target to match
    if length is not None and target.size(1) != mask.size(1):
        max_len = length.max()
        target = target[:, :max_len].contiguous()
    if mask is not None:
        # TODO for PyTorch 0.4: take the NLL loss first, then mask
        # (no need to broadcast the mask over the label dimension).
        # Zeroing the log-probabilities at padded positions makes them
        # contribute 0 to the summed loss; divide by the real token count.
        return self.nll_loss(self.logsoftmax(output) * mask.contiguous().view(output_size[0], 1),
                             target.view(-1)) / mask.sum(), \
               (torch.eq(preds, target).type_as(mask) * mask).sum(), preds
    else:
        # token count: output_size was flattened above, so
        # output_size[0] is already batch * length
        num = output_size[0]
        return self.nll_loss(self.logsoftmax(output), target.view(-1)) / num, \
               (torch.eq(preds, target).type_as(output)).sum(), preds
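
The masked branch relies on `self.nll_loss` summing (rather than averaging) the per-position losses: multiplying the log-probabilities by the mask zeroes the rows of padded positions, so `NLLLoss` picks up an exact 0 for them and only real tokens contribute to the sum, which is then divided by `mask.sum()`. Below is a minimal standalone sketch of that computation, with hypothetical sizes and assuming `self.logsoftmax` and `self.nll_loss` correspond to `LogSoftmax` and a sum-reduced `NLLLoss`:

import torch
import torch.nn as nn

# hypothetical sizes, for illustration only
batch, length, num_labels = 2, 5, 4
logits = torch.randn(batch * length, num_labels)
target = torch.randint(num_labels, (batch * length,))
# 1.0 for real tokens, 0.0 for padding
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.float)

logsoftmax = nn.LogSoftmax(dim=1)
nll_loss = nn.NLLLoss(reduction='sum')  # assumed: sum reduction, since the code divides by mask.sum()

# zeroing the masked rows makes every padded position contribute
# exactly 0 to the summed loss
masked_log_probs = logsoftmax(logits) * mask.view(batch * length, 1)
loss = nll_loss(masked_log_probs, target) / mask.sum()
print(loss)  # average NLL per real token

This trick only works because `NLLLoss` merely indexes the input at the target label; once the loss is taken first (as the TODO suggests), the mask could instead be applied directly to the per-position loss vector.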