def compute_accuracy(self, y, t):
    """Return the mean of arc accuracy and label accuracy over non-padded tokens.

    Args:
        y: pair ``(arc_logits, label_logits)`` of 3-D score tensors; the
            prediction for each token is the argmax over the last dimension.
            NOTE(review): presumably (batch, seq, n_heads) and
            (batch, seq, n_labels) — confirm against the model.
        t: gold data whose transpose unpacks into ``(true_arcs, true_labels)``.
            Each is a batch of per-sentence sequences, padded here with -1.
            NOTE(review): presumably a numpy object array of shape (batch, 2).

    Returns:
        0-dim float tensor: ``(arc_accuracy + label_accuracy) / 2``, where each
        accuracy is ``#correct / #non-padding tokens``. If every position is
        padding, the result is NaN (0 / 0).
    """

    def _masked_accuracy(logits, gold):
        # Accuracy of argmax(logits, dim=2) against gold, ignoring -1 padding.
        pred = logits.data.max(2)[1].cpu()
        # Move gold to CPU explicitly so the comparison with the CPU
        # predictions never mixes devices.
        gold = pad_sequence(gold, padding=-1, dtype=np.int64).cpu()
        # argmax indices are >= 0, so padded (-1) positions can never count
        # as correct; no extra mask is needed in the numerator.
        correct = pred.eq(gold).sum()
        total = gold.ne(-1).sum()
        # Cast before dividing: integer-tensor division floors in older
        # PyTorch, which would make the accuracy 0 for any imperfect batch.
        return correct.float() / total.float()

    arc_logits, label_logits = y
    true_arcs, true_labels = t.T
    arc_accuracy = _masked_accuracy(arc_logits, true_arcs)
    label_accuracy = _masked_accuracy(label_logits, true_labels)
    return (arc_accuracy + label_accuracy) / 2
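# Web-page navigation residue from the page this snippet was copied from:
# 评论列表 ("comment list"), 文章目录 ("table of contents").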