def _viterbi_decode(self, feats):
    """Decode the highest-scoring tag sequence for ``feats`` via Viterbi.

    Args:
        feats: iterable of per-timestep emission scores; each element is
            reshaped to a (tagset_size, 1) column below, so it is assumed
            to hold ``tagset_size`` values — TODO confirm against caller.

    Returns:
        (path_score, best_path): the (unnormalized) score of the best
        path as a 1-element tensor, and the best path as a list of tag
        indices (Python ints), excluding the START/STOP tags.
    """
    backpointers = []

    # Initialize: every tag impossible (-10000) except START, which is 0.
    init_vvars = torch.Tensor(self.tagset_size, 1).fill_(-10000.).type(self.dtype)
    init_vvars[self.tag_to_ix[self.START_TAG]][0] = 0

    # forward_var[i] = score of the best path so far ending in tag i,
    # kept as a (tagset_size, 1) column vector throughout the loop.
    forward_var = autograd.Variable(init_vvars).type(self.dtype)
    for feat in feats:
        # scores[next, prev] = forward_var[prev] + transitions[next, prev]
        scores = self.transitions + torch.transpose(
            forward_var.expand(forward_var.size(0), self.tagset_size), 0, 1)
        # Best previous tag for each next tag. keepdim=True is the fix:
        # with the keepdim=False default (PyTorch >= 0.2, which this code
        # targets given the explicit keepdim=True on the terminal max
        # below), these come back 1-D and the addition on the next line
        # would broadcast (T, 1) + (T,) into a (T, T) matrix, corrupting
        # forward_var. keepdim=True preserves the (T, 1) column shape.
        viterbi_vars, viterbi_idx = torch.max(scores, 1, keepdim=True)
        forward_var = feat.view(self.tagset_size, 1) + viterbi_vars
        backpointers.append(viterbi_idx)

    # Close every path with the transition into STOP.
    terminal_var = forward_var + self.transitions[self.tag_to_ix[self.STOP_TAG]].view(self.tagset_size, 1)
    _, best_tag_id = torch.max(terminal_var, 0, keepdim=True)
    best_tag_id = to_scalar(best_tag_id)
    path_score = terminal_var[best_tag_id]

    # Follow the backpointers to recover the best path (built backwards).
    best_path = [best_tag_id]
    for bptrs_t in reversed(backpointers):
        best_tag_id = to_scalar(bptrs_t[best_tag_id])
        best_path.append(best_tag_id)
    start = best_path.pop()  # drop the START tag from the path
    assert start == self.tag_to_ix[self.START_TAG]  # Sanity check
    best_path.reverse()
    return path_score, best_path