def extract_best_label_logits(self, arc_logits, label_logits, lengths):
    """Select, for every token, the label logits of its highest-scoring head.

    Args:
        arc_logits: arc score tensor whose dim 1 is reduced with ``max`` to
            pick a head per token -- presumably (batch, head, dependent);
            TODO confirm against the caller.
        label_logits: 4-D label score tensor (its ``size()`` is read up to
            index 3). Indexed here as ``label_logits[b, head, dependent]``,
            matching the head-first indexing of the previous implementation;
            TODO confirm the (head, dependent) order.
        lengths: per-sentence token counts; only the first ``lengths[b]``
            token positions of sentence ``b`` are filled.

    Returns:
        The tensor produced by ``_model_var`` from zeros of shape
        ``(size[0], size[1], size[3])`` (``size = label_logits.size()``),
        where entry ``[b, i]`` holds the label logits of the arc from token
        ``i``'s predicted head to token ``i``. Positions at or beyond
        ``lengths[b]`` keep their initial zero value.
    """
    batch_size = arc_logits.size(0)
    # Predicted head index per token. Older PyTorch keeps the reduced dim in
    # ``max`` and newer versions drop it; ``view`` gives (batch, n) in both
    # cases, whereas ``squeeze(dim=1)`` would also collapse the token axis
    # when it has length 1 on newer versions.
    pred_arcs = torch.max(arc_logits, dim=1)[1].view(
        batch_size, -1).data.cpu().numpy()
    size = label_logits.size()
    output_logits = _model_var(
        self.model,
        torch.zeros(size[0], size[1], size[3]))
    for batch_index, (_logits, _arcs, _length) \
            in enumerate(zip(label_logits, pred_arcs, lengths)):
        for i in range(_length):
            # Fill only token i's row with the scores of arc
            # (head=_arcs[i] -> dependent=i). Assigning to the whole
            # output_logits[batch_index] slot would overwrite it on every
            # iteration and keep just the last token's head row.
            output_logits[batch_index, i] = _logits[int(_arcs[i]), i]
    return output_logits
# Comment list (web-page residue)
# Table of contents (web-page residue)