def transition_score(self, labels, lens):
    """
    Sum the transition scores along each label sequence in the batch.

    Every sequence is wrapped as ``<start>, y_1 .. y_len, <stop>`` and the
    result is the sum of ``self.transitions[next, prev]`` over the
    ``len + 1`` adjacent label pairs of the wrapped sequence. Pairs that lie
    in the padded region are masked out.

    Arguments:
        labels: [batch_size, seq_len] LongTensor
        lens: [batch_size] LongTensor

    Returns:
        [batch_size] tensor holding one summed transition score per sequence.
    """
    batch_size, seq_len = labels.size()

    # pad labels with <start> and <stop> indices; the buffer is pre-filled
    # with <stop> so no column is ever read while still uninitialised
    labels_ext = Variable(
        labels.data.new(batch_size, seq_len + 2).fill_(self.stop_idx))
    labels_ext[:, 0] = self.start_idx
    labels_ext[:, 1:-1] = labels
    # NOTE(review): assumes sequence_mask(n, max_len=m) yields a
    # [batch_size, m] mask that is 1 for positions < n -- confirm against
    # its definition. Positions 0..len keep <start> + the real labels,
    # every later position is overwritten with <stop>.
    mask = sequence_mask(lens + 1, max_len=seq_len + 2).long()
    pad_stop = Variable(labels.data.new(1).fill_(self.stop_idx))
    pad_stop = pad_stop.unsqueeze(-1).expand(batch_size, seq_len + 2)
    labels_ext = (1 - mask) * pad_stop + mask * labels_ext
    labels = labels_ext

    trn = self.transitions

    # obtain transition vector for each label in batch and timestep
    # (except the first ones): trn_row[b, t, :] == trn[labels[b, t + 1], :]
    trn_exp = trn.unsqueeze(0).expand(batch_size, *trn.size())
    lbl_r = labels[:, 1:]
    lbl_rexp = lbl_r.unsqueeze(-1).expand(*lbl_r.size(), trn.size(0))
    trn_row = torch.gather(trn_exp, 1, lbl_rexp)

    # obtain transition score from the transition vector for each label
    # in batch and timestep (except the last ones):
    # trn_scr[b, t] == trn[labels[b, t + 1], labels[b, t]]
    lbl_lexp = labels[:, :-1].unsqueeze(-1)
    trn_scr = torch.gather(trn_row, 2, lbl_lexp)
    trn_scr = trn_scr.squeeze(-1)

    # trn_scr always has seq_len + 1 columns, so the mask width is pinned
    # explicitly instead of being derived from the lengths in the batch
    # (which would mismatch when the batch is padded past its longest item)
    mask = sequence_mask(lens + 1, max_len=seq_len + 1).float()
    trn_scr = trn_scr * mask
    # view() instead of squeeze(-1): keeps the [batch_size] shape even when
    # batch_size == 1, whether or not sum() retains the reduced dimension
    score = trn_scr.sum(1).view(batch_size)

    return score
评论列表
文章目录