pytorch_model.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:biaffineparser 作者: chantera 项目源码 文件源码
def compute_loss(self, y, t):
        arc_logits, label_logits = y
        true_arcs, true_labels = t.T

        b, l1, l2 = arc_logits.size()
        true_arcs = _model_var(
            self.model,
            pad_sequence(true_arcs, padding=-1, dtype=np.int64))
        arc_loss = F.cross_entropy(
            arc_logits.view(b * l1, l2), true_arcs.view(b * l1),
            ignore_index=-1)

        b, l1, d = label_logits.size()
        true_labels = _model_var(
            self.model,
            pad_sequence(true_labels, padding=-1, dtype=np.int64))
        label_loss = F.cross_entropy(
            label_logits.view(b * l1, d), true_labels.view(b * l1),
            ignore_index=-1)

        loss = arc_loss + label_loss
        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号