temperature_scaling.py source code

python

Project: temperature_scaling  Author: gpleiss
import torch
import torch.nn.functional as F

# Method of a module class (not shown) that defines self.bin_lowers and self.bin_uppers.
def forward(self, logits, labels):
    # Softmax over the class dimension (dim=1); the top probability is the confidence.
    softmaxes = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(softmaxes, 1)
    accuracies = predictions.eq(labels)

    ece = torch.zeros(1, dtype=logits.dtype, device=logits.device)
    for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
        # Calculate |confidence - accuracy| in each bin, weighted by the bin's share of samples
        in_bin = confidences.gt(bin_lower) & confidences.le(bin_upper)
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece
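
To make the binning arithmetic concrete, here is a minimal dependency-free sketch of the same computation on plain Python lists. It is not part of the original file: the function name `ece_from_confidences`, the `n_bins` parameter, and the choice of equal-width bins (lower, upper] are assumptions made for illustration, since the snippet above does not show how `self.bin_lowers` and `self.bin_uppers` are built.

```python
def ece_from_confidences(confidences, correct, n_bins=10):
    """Hypothetical helper: expected calibration error over equal-width bins."""
    n = len(confidences)
    ece = 0.0
    for b in range(n_bins):
        # Assumed bin layout: (lower, upper], matching the gt/le tests above
        lower, upper = b / n_bins, (b + 1) / n_bins
        idx = [i for i, c in enumerate(confidences) if lower < c <= upper]
        if not idx:
            continue  # empty bins contribute nothing
        avg_conf = sum(confidences[i] for i in idx) / len(idx)
        acc = sum(1.0 for i in idx if correct[i]) / len(idx)
        # |confidence - accuracy| weighted by the fraction of samples in the bin
        ece += abs(avg_conf - acc) * len(idx) / n
    return ece


# Two samples at 0.95 confidence (both correct), two at 0.55 (one correct)
confidences = [0.95, 0.95, 0.55, 0.55]
correct = [True, True, True, False]
ece = ece_from_confidences(confidences, correct, n_bins=10)
print(round(ece, 4))  # 0.05
```

Each occupied bin has a gap of 0.05 between average confidence and accuracy and holds half of the samples, so the weighted sum is 0.5 * 0.05 + 0.5 * 0.05 = 0.05.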