def _bilstm_score(self, logits, y, lens):
    """Return the emission score of the gold label path for each sequence.

    For every sequence ``b`` this sums ``logits[b, t, y[b, t]]`` over its
    first ``lens[b]`` time steps; padded time steps contribute nothing.

    Args:
        logits: Emission scores laid out as ``(batch, seq_len, n_labels)``.
        y: Gold label indices, ``(batch, seq_len)``. Values at padded
            positions are ignored, so sentinels such as ``-1`` are allowed.
        lens: Valid length of each sequence, ``(batch,)``; a tensor or any
            sequence accepted by ``torch.as_tensor``.

    Returns:
        Tensor of shape ``(batch,)`` (also for a batch of one) in the
        floating dtype of ``logits``.
    """
    seq_len = logits.size(1)
    lens = torch.as_tensor(lens, device=logits.device)
    # Build the length mask against logits' own time axis so its width always
    # matches ``scores``, even when inputs are padded past the longest sequence.
    steps = torch.arange(seq_len, device=logits.device)
    mask = steps.unsqueeze(0) < lens.unsqueeze(1)  # (batch, seq_len), bool
    # gather needs int64 indices that are in range everywhere; padded slots are
    # discarded below, so any valid index (0) can stand in for them.
    labels = y.long().masked_fill(~mask, 0)
    scores = torch.gather(logits, 2, labels.unsqueeze(-1)).squeeze(-1)
    # masked_fill rather than multiplying by a float mask: keeps logits' dtype
    # and avoids ``inf * 0 == nan`` when padded logits are non-finite.
    scores = scores.masked_fill(~mask, 0)
    # sum(1) already drops the time axis; a trailing squeeze(-1) would
    # collapse a batch of one into a 0-dim tensor.
    return scores.sum(1)
评论列表
文章目录