def _score_sentence(self, input, tags):
    """Score the gold tag path of each sentence in a batch (CRF numerator).

    The path is implicitly bracketed by the ``START`` and ``STOP`` tags. For
    every time step the score accumulates the transition score from the
    previous tag into the current tag plus the emission score of the current
    tag; finally the transition from the last tag into ``STOP`` is added.

    Args:
        input: emission scores. ``input.size()`` must unpack into exactly
            three values. Presumably laid out as ``(batch, seq_len, num_tags)``
            -- TODO confirm against the caller; the loop below walks dim 1.
        tags: gold tag indices (LongTensor), assumed ``(batch, seq_len)`` and
            on the same device as tensors built from ``self.torch`` -- TODO
            confirm.

    Returns:
        A tensor holding one accumulated path score per batch element.

    Notes:
        ``self.transitions`` is read as ``transitions[next_tag, prev_tag]``:
        columns are selected by the previous tag, then one row per batch
        element is picked by the next tag. This assumes ``gather_index(x, idx)``
        picks ``x[b, idx[b]]`` for each row ``b`` -- verify against its
        definition. Every time step is scored; no padding mask is applied.
    """
    # Only the batch size is needed; unpacking three values still rejects
    # inputs that are not 3-D, exactly as before.
    batch_size, _, _ = input.size()
    score = Variable(self.torch.FloatTensor(batch_size).fill_(0.))

    # Prepend START so that step i scores the move tags[:, i] -> tags[:, i + 1].
    start_col = Variable(self.torch.LongTensor([[START]] * batch_size))
    # Global ``torch.cat``: ``self.torch`` may be a device submodule without it.
    tags = torch.cat([start_col, tags], dim=-1)

    # Walk over time; with the layout assumed above, each step yields the
    # (batch, num_tags) emission slice for that position.
    for i, emissions in enumerate(input.transpose(0, 1)):
        prev_tags, next_tags = tags[:, i], tags[:, i + 1]
        # One transitions column per batch element, chosen by its previous tag.
        from_prev = self.transitions.index_select(1, prev_tags)
        trans_score = gather_index(from_prev.transpose(0, 1), next_tags)
        emit_score = gather_index(emissions, next_tags)
        score = score + trans_score + emit_score

    # Close every path with the transition last tag -> STOP.
    from_last = self.transitions.index_select(1, tags[:, -1])
    stop_tags = Variable(self.torch.LongTensor([STOP] * batch_size))
    stop_score = gather_index(from_last.transpose(0, 1), stop_tags)
    return score + stop_score
评论列表
文章目录