def forward(self, input):
    """Return the CRF forward-algorithm score for a batch of emission scores.

    ``forward_var`` holds, for every batch element and every tag, the running
    log-sum-exp score of all tag paths that end in that tag. Each time step
    extends those paths by one tag: for every ``next_tag`` the running scores,
    the transition scores into ``next_tag`` and the emission score of
    ``next_tag`` are summed and reduced over the previous tag.

    Args:
        input: emission-score tensor, unpacked as ``(bsz, sent_len, l_size)``.
            Its last dimension is indexed by tag and combined element-wise
            with rows of ``self.transitions``, so ``l_size`` is expected to
            equal ``self.label_size``.

    Returns:
        ``log_sum_exp(forward_var)`` over the final ``(bsz, label_size)``
        forward variables. Presumably this is one log-partition value per
        batch element; the exact shape depends on the ``log_sum_exp`` helper
        defined elsewhere (TODO confirm).
    """
    # sent_len is not needed; the unpacking still enforces a rank-3 input.
    bsz, _, l_size = input.size()

    # Initial state: the START tag scores 0 and every other tag scores
    # -10000, so paths not beginning at START are effectively excluded.
    init_alphas = self.torch.FloatTensor(bsz, self.label_size).fill_(-10000.)
    init_alphas[:, START].fill_(0.)
    forward_var = Variable(init_alphas)

    # Row ``next_tag`` of self.transitions is added to the scores over every
    # previous tag. It does not depend on the time step, so build the
    # (bsz, l_size) views once rather than once per word.
    trans_rows = [
        self.transitions[next_tag, :].view(1, -1).expand(bsz, l_size)
        for next_tag in range(self.label_size)
    ]

    # Walk the sequence one time step at a time; ``words`` is (bsz, l_size).
    for words in input.transpose(0, 1):
        alphas_t = []
        for next_tag, trans_score in enumerate(trans_rows):
            # Emission score of next_tag, repeated across previous tags.
            # unsqueeze/expand are views, so no contiguous() copy is needed.
            emit_score = words[:, next_tag].unsqueeze(1).expand(bsz, l_size)
            next_tag_var = forward_var + trans_score + emit_score
            # Reduce over the previous tag; the True flag presumably keeps
            # the reduced dim so the results concatenate to (bsz, label_size).
            alphas_t.append(log_sum_exp(next_tag_var, True))
        forward_var = torch.cat(alphas_t, dim=-1)

    # NOTE(review): no transition into a STOP/end tag is added before the
    # final reduction. If the gold-path scoring code adds one, the partition
    # score should include it too -- confirm against that code.
    return log_sum_exp(forward_var)
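# Scraped web-page residue, not code: "Comment list"
# Scraped web-page residue, not code: "Table of contents"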