def forward(self, logits, labels):
softmaxes = F.softmax(logits)
confidences, predictions = torch.max(softmaxes, 1)
accuracies = predictions.eq(labels)
ece = Variable(torch.zeros(1)).type_as(logits)
for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
# Calculated |confidence - accuracy| in each bin
in_bin = confidences.gt(bin_lower) * confidences.le(bin_upper)
prop_in_bin = in_bin.float().mean()
if prop_in_bin.data[0] > 0:
accuracy_in_bin = accuracies[in_bin].float().mean()
avg_confidence_in_bin = confidences[in_bin].mean()
ece += torch.abs(avg_confidence_in_bin- accuracy_in_bin) * prop_in_bin
return ece
temperature_scaling.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录