def forward(self, pred, labels, targets):
    """Compute a class-specific smooth-L1 regression loss.

    For every sample ``i`` the two values ``pred[i, labels[i] - 1, 0:2]``
    are selected, flattened in (sample, channel) order, compared against
    the flattened ``targets`` with ``self.smooth_l1_loss``, and the result
    is multiplied by 2.

    Args:
        pred: Tensor indexed as ``pred[sample, class, channel]`` with at
            least two channels; only channels 0 and 1 are used.
            NOTE(review): presumably (batch, num_classes, 2) — confirm.
        labels: 1-D integer tensor of 1-based class indices, one per
            sample; ``pred`` must have at least ``labels.size(0)`` rows
            (normally exactly that many).
        targets: Tensor whose flattened form lines up element-wise with
            the selected predictions (``2 * labels.size(0)`` elements).

    Returns:
        ``self.smooth_l1_loss(selected_preds, targets) * 2``.
    """
    # Convert 1-based class labels to 0-based indices into dim 1 of pred.
    indexer = labels.data - 1
    # Select pred[i, indexer[i], :2] for each sample i directly. The former
    # pred[:, indexer, :] + torch.diag(...) version materialised a (B, B, C)
    # tensor only to keep its diagonal: O(B^2) memory for the same result.
    rows = torch.arange(indexer.size(0), device=indexer.device)
    class_pred = pred[rows, indexer, :2]
    # NOTE(review): the reason for the factor 2 is not visible here —
    # possibly it undoes averaging over the two channels; confirm.
    loss = self.smooth_l1_loss(class_pred.reshape(-1), targets.view(-1)) * 2
    return loss
评论列表
文章目录