def extract_best_label_logits(self, arc_logits, label_logits, lengths):
    """Pick, per position, the label logits that belong to the best arc.

    The best arc for each position is the argmax of ``arc_logits.data``
    along axis 1. ``label_logits`` (4-D) has its axes 1 and 2 swapped, and
    then for every sample the entries ``[i, best_arc[i]]`` are gathered for
    ``i < length``, leaving the trailing axis intact. The per-sample
    results are finally combined with ``F.pad_sequence``.

    Args:
        arc_logits: Object exposing ``.data``; presumably laid out as
            (batch, head, dependent) -- TODO confirm against the caller.
        label_logits: 4-D logits; axis meaning is not visible here --
            presumably (batch, head, dependent, label), TODO confirm.
        lengths: Per-sample number of valid positions.

    Returns:
        The result of ``F.pad_sequence`` applied to the gathered
        per-sample label logits.
    """
    xp = self.model.xp
    best_arcs = xp.argmax(arc_logits.data, axis=1)

    # Swap axes 1 and 2 so each sample can be indexed as [position, arc].
    swapped = F.transpose(label_logits, (0, 2, 1, 3))

    selected = []
    for sample_logits, sample_arcs, n_valid in zip(swapped, best_arcs, lengths):
        # NOTE(review): positions are built with NumPy while the arcs come
        # from ``xp``; verify this mix is fine when ``xp`` is not NumPy.
        positions = np.arange(n_valid)
        arcs = sample_arcs[:n_valid]
        selected.append(sample_logits[positions, arcs])

    return F.pad_sequence(selected)
评论列表
文章目录