def decode(self, input_word, input_char, target=None, mask=None, length=None, hx=None, leading_symbolic=0):
    """Decode a tag sequence and optionally count predictions matching ``target``.

    Args:
        input_word: word-level input, forwarded to ``self._get_rnn_output``.
        input_char: character-level input, forwarded to ``self._get_rnn_output``.
        target: optional gold tags, indexed as ``[batch, length]``. When
            ``None`` no match count is computed.
        mask: optional mask forwarded to ``self._get_rnn_output``; the mask
            that call returns is what decoding and counting actually use.
        length: optional sequence lengths forwarded to
            ``self._get_rnn_output``; the value that call returns is used to
            trim ``target``.
        hx: optional initial hidden state forwarded to
            ``self._get_rnn_output``.
        leading_symbolic: forwarded unchanged to ``self.crf.decode``.

    Returns:
        A ``(preds, corr)`` tuple. ``preds`` is whatever ``self.crf.decode``
        returns. ``corr`` is ``None`` when ``target`` is ``None``; otherwise it
        is the sum of element-wise matches between ``preds`` and ``target``,
        weighted by the mask when one is available.
    """
    # output from rnn [batch, length, tag_space]
    output, _, mask, length = self._get_rnn_output(input_word, input_char, mask=mask, length=length, hx=hx)

    # Decoding does not depend on ``target``, so do it once for both paths.
    preds = self.crf.decode(output, mask=mask, leading_symbolic=leading_symbolic)
    if target is None:
        return preds, None

    if length is not None:
        # Drop trailing columns of the gold tags beyond the longest sequence.
        max_len = length.max()
        target = target[:, :max_len]

    matches = torch.eq(preds, target.data).float()
    if mask is not None:
        # Weight matches by the mask so masked-out positions are not counted.
        matches = matches * mask.data
    return preds, matches.sum()
评论列表
文章目录