def compute_loss(self, y, t):
    """Return the summed arc and label cross-entropy losses for a batch.

    Args:
        y: pair ``(arc_logits, label_logits)``. Each tensor's ``.size()`` is
            unpacked into three dimensions, and the first two are flattened
            together before scoring.
        t: gold targets. ``t.T`` is unpacked into ``(true_arcs, true_labels)``,
            which are presumably per-sentence sequences of gold head indices
            and label ids (TODO confirm against the caller).

    Returns:
        ``arc_loss + label_loss``. Each term is an ``F.cross_entropy`` value
        that skips positions whose padded target is ``-1``.
    """
    def _padded_xent(logits, gold):
        # Pad the ragged gold sequences with -1 so they line up with the
        # (batch, length, classes) logits; cross_entropy then ignores the
        # -1 padding positions.
        batch, length, n_classes = logits.size()
        target = _model_var(
            self.model,
            pad_sequence(gold, padding=-1, dtype=np.int64))
        return F.cross_entropy(
            logits.view(batch * length, n_classes),
            target.view(batch * length),
            ignore_index=-1)

    arc_logits, label_logits = y
    true_arcs, true_labels = t.T
    arc_loss = _padded_xent(arc_logits, true_arcs)
    label_loss = _padded_xent(label_logits, true_labels)
    return arc_loss + label_loss
# 评论列表 (stray web-page text: "Comment list")
# 文章目录 (stray web-page text: "Table of contents")