metrics.py source code

python

Project: tefla    Author: openAGI
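import os

import numpy as np

# util, convert, convert_labels and compute_hist are helpers defined elsewhere
# in the tefla codebase; they are not shown in this file excerpt.


def meaniou(self, predictor, predict_dir, image_size):
    """Return the mean intersection-over-union (IoU) over all .jpg images in predict_dir."""
    segparams = util.SegParams()
    classes = segparams.feature_classes().values()
    # +1 extra label, presumably for the background class.
    num_classes = len(classes) + 1
    # Confusion matrix accumulated over every image in the directory.
    hist = np.zeros((num_classes, num_classes))
    image_names = [filename.strip() for filename in os.listdir(predict_dir)
                   if filename.endswith('.jpg')]
    for image_filename in image_names:
        final_prediction_map = predictor.predict(
            os.path.join(predict_dir, image_filename))
        final_prediction_map = final_prediction_map.transpose(0, 2, 1).squeeze()
        # The ground-truth mask sits next to the image: foo.jpg -> foo_final_mask.png
        gt_name = os.path.join(
            predict_dir, os.path.splitext(image_filename)[0] + '_final_mask.png')
        gt = convert(gt_name, image_size)
        gt = np.asarray(gt)
        gt = convert_labels(gt, image_size, image_size)
        hist += compute_hist(gt, final_prediction_map, num_classes=num_classes)
    # Per-class IoU = TP / (TP + FP + FN). A class absent from both the ground
    # truth and the predictions gives 0/0 = NaN, which nanmean skips.
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    mean_iou = np.nanmean(iou)

    return mean_iou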
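The method builds a confusion matrix with tefla's compute_hist helper, but that helper's body is not shown in this file. The sketch below assumes compute_hist follows the common np.bincount formulation, with rows indexing ground truth and columns indexing prediction. Under that assumption, the hypothetical functions compute_hist_sketch and mean_iou_from_hist reproduce the histogram and mean-IoU steps on a toy 2×3 label map. Neither function is part of tefla's API.

```python
import numpy as np


def compute_hist_sketch(gt, pred, num_classes):
    """Hypothetical stand-in for tefla's compute_hist (an assumption, not its code).

    Returns a num_classes x num_classes confusion matrix with
    rows = ground-truth label and columns = predicted label.
    Predicted labels are assumed to lie in [0, num_classes).
    """
    gt = np.asarray(gt).ravel()
    pred = np.asarray(pred).ravel()
    # Skip pixels whose ground-truth label is outside [0, num_classes).
    valid = (gt >= 0) & (gt < num_classes)
    # Encode each (gt, pred) pair as a single bin index, then count the bins.
    index = num_classes * gt[valid].astype(int) + pred[valid].astype(int)
    counts = np.bincount(index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def mean_iou_from_hist(hist):
    """Per-class IoU = TP / (TP + FP + FN), averaged over classes that occur."""
    tp = np.diag(hist)
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = tp / (hist.sum(axis=1) + hist.sum(axis=0) - tp)
    # Classes that never appear give 0/0 = NaN and are ignored by nanmean.
    return np.nanmean(iou)


if __name__ == '__main__':
    gt = np.array([[0, 0, 1],
                   [1, 2, 2]])
    pred = np.array([[0, 1, 1],
                     [1, 2, 0]])
    hist = compute_hist_sketch(gt, pred, num_classes=3)
    print(hist)
    print(mean_iou_from_hist(hist))
```

Per-class IoU is $n_{cc} / (\sum_j n_{cj} + \sum_i n_{ic} - n_{cc})$, which is symmetric in rows and columns. The mean IoU therefore does not depend on which axis holds the ground truth. In the toy example, the three classes score 1/3, 2/3 and 1/2, for a mean of 0.5.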