def forward(self, output, target):
    """Compute the balanced focal loss, averaged over all samples.

    Each sample's cross-entropy ``CE = -log p_t`` (``p_t`` being the
    predicted probability of the true class) is scaled by the modulating
    factor ``(1 - p_t) ** self.focusing_param``. This down-weights
    well-classified examples. The mean is then multiplied by
    ``self.balance_param``.

    Args:
        output: Unnormalized logits in any layout ``F.cross_entropy``
            accepts, e.g. ``(N, C)``.
        target: Integer class indices matching ``output``, e.g. ``(N,)``.

    Returns:
        A scalar tensor equal to
        ``balance_param * mean((1 - p_t) ** focusing_param * CE)``.
        With ``focusing_param == 0`` and ``balance_param == 1``, this is
        exactly the mean cross-entropy.
    """
    # Per-sample CE (= -log p_t). The focal weight must be applied before
    # averaging. Otherwise it modulates the batch mean rather than each
    # individual example.
    ce = F.cross_entropy(output, target, reduction="none")
    # Recover p_t from CE. Since CE >= 0, p_t lies in (0, 1], so the base
    # (1 - p_t) is never negative.
    pt = torch.exp(-ce)
    modulating = (1.0 - pt) ** self.focusing_param
    focal_loss = (modulating * ce).mean()
    balanced_focal_loss = self.balance_param * focal_loss
    return balanced_focal_loss
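# 评论列表 (stray web-page label: "Comment list"; not part of the program)
# 文章目录 (stray web-page label: "Table of contents"; not part of the program)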