pytorch_model.py source code

python

Project: biaffineparser  Author: chantera
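import numpy as np


def compute_accuracy(self, y, t):
    # y: (arc_logits, label_logits) produced by the model
    # t: per-sentence (arcs, labels) gold pairs; .T splits them into two sequences
    arc_logits, label_logits = y
    true_arcs, true_labels = t.T

    # Arc accuracy: argmax over candidate heads (dim 2) vs. gold heads.
    b, l1, l2 = arc_logits.size()
    pred_arcs = arc_logits.data.max(2)[1].cpu()
    # pad_sequence is a helper from the biaffineparser project (not
    # torch.nn.utils.rnn.pad_sequence); it pads the gold heads with -1.
    true_arcs = pad_sequence(true_arcs, padding=-1, dtype=np.int64)
    correct = float(pred_arcs.eq(true_arcs).cpu().sum())
    # Padded positions (-1) never equal an argmax index, so they are only
    # excluded from the denominator. float() avoids integer division.
    n_tokens = b * l1 - np.sum(true_arcs.cpu().numpy() == -1)
    arc_accuracy = correct / n_tokens

    # Label accuracy: same procedure over the label dimension.
    b, l1, d = label_logits.size()
    pred_labels = label_logits.data.max(2)[1].cpu()
    true_labels = pad_sequence(true_labels, padding=-1, dtype=np.int64)
    correct = float(pred_labels.eq(true_labels).cpu().sum())
    n_tokens = b * l1 - np.sum(true_labels.cpu().numpy() == -1)
    label_accuracy = correct / n_tokens

    # Report the mean of arc and label accuracy.
    accuracy = (arc_accuracy + label_accuracy) / 2
    return accuracy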
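The method above depends on the project's own `pad_sequence` helper and on PyTorch tensors. As a minimal sketch of the same padding-masked accuracy that runs standalone, here is a NumPy version (NumPy swapped in for PyTorch). It assumes gold heads and labels are already padded with -1 into `(batch, length)` arrays and that logits have shape `(batch, length, n_classes)`. `masked_accuracy` and `parser_accuracy` are hypothetical names, not part of biaffineparser.

```python
import numpy as np


def masked_accuracy(logits, gold, pad=-1):
    """Accuracy of argmax(logits) against gold, ignoring padded positions.

    logits: (batch, length, n_classes) scores
    gold:   (batch, length) gold indices, padded with `pad`
    """
    pred = logits.argmax(axis=2)        # predicted index per token
    mask = gold != pad                  # True only at real (non-padded) tokens
    correct = np.sum((pred == gold) & mask)
    return correct / float(mask.sum())  # float division avoids truncation


def parser_accuracy(arc_logits, label_logits, gold_arcs, gold_labels, pad=-1):
    # Mean of head (arc) accuracy and label accuracy, mirroring compute_accuracy.
    arc_acc = masked_accuracy(arc_logits, gold_arcs, pad)
    label_acc = masked_accuracy(label_logits, gold_labels, pad)
    return (arc_acc + label_acc) / 2


# Hypothetical batch: 2 sentences padded to length 3 (the second has 2 tokens).
gold_arcs = np.array([[0, 0, 1], [0, 0, -1]])
gold_labels = np.array([[2, 1, 0], [1, 0, -1]])
# One-hot "logits" so that argmax returns exactly these predictions.
arc_logits = np.eye(3)[np.array([[0, 0, 2], [0, 0, 0]])]
label_logits = np.eye(3)[np.array([[2, 1, 1], [1, 1, 0]])]

print(masked_accuracy(arc_logits, gold_arcs))      # 4 of 5 heads correct
print(masked_accuracy(label_logits, gold_labels))  # 3 of 5 labels correct
print(parser_accuracy(arc_logits, label_logits, gold_arcs, gold_labels))
```

Like the original method, this scores labels independently of whether the predicted head is correct. The label figure is therefore not the labeled attachment score (LAS), which counts a token only when both head and label match.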