def viterbi_decode(self, logits, lens):
    """Find the highest-scoring label sequence for every batch element (Viterbi).

    Borrowed from the PyTorch BiLSTM-CRF tutorial and vectorised over the batch.
    ``self.transitions[i, j]`` is used as the score of moving from label ``j``
    to label ``i``. Decoding starts in ``self.start_idx``, and the transition
    into ``self.stop_idx`` is added right after each sequence's last token.

    Arguments:
        logits: [batch_size, seq_len, n_labels] FloatTensor of emission scores.
        lens: [batch_size] LongTensor with the true length of each sequence;
            steps at or beyond a sequence's length are treated as padding.

    Returns:
        scores: [batch_size] tensor with the best path score of each sequence.
        paths: [batch_size, seq_len] LongTensor of label indices. For a
            sequence of length L the decoded labels are ``paths[b, :L]``; the
            padding positions after it repeat the label at position L-1.
    """
    batch_size, seq_len, n_labels = logits.size()

    # Best score of any path ending in each label; before the first token only
    # the start label is reachable.
    vit = logits.new_full((batch_size, n_labels), -10000.0)
    vit[:, self.start_idx] = 0
    stop_trans = self.transitions[self.stop_idx].unsqueeze(0)  # [1, n_labels]
    # Back-pointers used on padded steps: each label points to itself, so the
    # backtrace passes through a finished sequence's padding unchanged.
    identity = torch.arange(n_labels, device=logits.device).unsqueeze(0)

    # Remaining length per sequence; not modified in place, so `lens` is safe.
    c_lens = lens.to(logits.device)
    pointers = []
    for logit in logits.transpose(0, 1):  # logit: [batch_size, n_labels]
        # trans_sum[b, i, j] = vit[b, j] + transitions[i, j]
        trans_sum = vit.unsqueeze(1) + self.transitions.unsqueeze(0)
        vt_max, vt_argmax = trans_sum.max(2)  # both [batch_size, n_labels]
        vit_nxt = vt_max + logit

        # Only sequences that still have tokens advance; finished ones keep
        # their score and get identity back-pointers.
        active = (c_lens > 0).unsqueeze(-1)  # [batch_size, 1]
        pointers.append(torch.where(active, vt_argmax, identity))
        vit = torch.where(active, vit_nxt, vit)

        # Close each sequence with the transition into the stop label right
        # after its last real token.
        is_last = (c_lens == 1).unsqueeze(-1)
        vit = torch.where(is_last, vit + stop_trans, vit)

        c_lens = c_lens - 1

    pointers = torch.stack(pointers)  # [seq_len, batch_size, n_labels]
    scores, idx = vit.max(1)  # both [batch_size]

    # Walk back from the best final label. pointers[t][b, i] is the best label
    # at step t-1 given label i at step t; pointers[0] would only recover the
    # start label, so it is not needed.
    path = [idx]
    for t in range(seq_len - 1, 0, -1):
        # gather yields [batch_size, 1]; squeeze only that dim so a batch of
        # size 1 keeps its batch dimension.
        idx = pointers[t].gather(1, idx.unsqueeze(1)).squeeze(1)
        path.append(idx)
    path.reverse()
    paths = torch.stack(path, 1)  # [batch_size, seq_len]

    return scores, paths
评论列表
文章目录