class_heatmap.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:unblackboxing_webinar 作者: deepsense-ai 项目源码 文件源码
def plot(self, vis_func, img_path, label_list, figsize):
    """Return an interactive viewer of a class heatmap for one image.

    The image at ``img_path`` is loaded at ``self.img_shape_``, classified
    with ``self.model_`` (scores passed through ``softmax``), and an
    interactive widget is returned.  For the chosen class index the widget
    prints the predictions, then draws the ``vis_func`` heatmap overlaid on
    the image (left) next to the plain image (right), titled with the class
    name and its rounded score.

    Args:
        vis_func: Callable ``vis_func(img, label_id)`` returning a heatmap
            that ``overlay(heatmap, img)`` can blend with the image.
        img_path: Path of the image to load.
        label_list: Class names indexed by class id.  When falsy, names come
            from ``get_pred_text_label`` and the printed summary from
            ``decode_predictions`` instead (presumably the model's built-in
            label set -- confirm against those helpers).
        figsize: Figure size passed to ``plt.figure``.

    Returns:
        The result of ``interact(_plot, label_id='1')``.  The string default
        presumably makes ``interact`` render a text box starting at class 1.
    """
    img = utils.load_img(img_path, target_size=self.img_shape_)
    # Keep only the first three channels (e.g. drops an alpha channel).
    img = img[:, :, :3]

    predictions = self.model_.predict(img2tensor(img, self.img_shape_))
    predictions = softmax(predictions)
    # Scores are read as predictions[0, label_id], so axis 1 is the class axis.
    num_classes = np.shape(predictions)[1]

    # Only the label lookup and the printed summary differ between the
    # "custom label list" and "built-in labels" cases; plotting is shared.
    if label_list:
        def _label_name(label_id):
            return label_list[label_id]

        def _print_predictions():
            print(pd.DataFrame(predictions, columns=label_list))
    else:
        prediction_text = decode_predictions(predictions)[0]

        def _label_name(label_id):
            return get_pred_text_label(label_id)

        def _print_predictions():
            for p in prediction_text:
                print(p[1:])

    def _plot(label_id):
        # The widget hands over raw text, which is empty or non-numeric
        # while the user is editing it; report instead of raising.
        try:
            label_id = int(label_id)
        except (TypeError, ValueError):
            print('label_id must be an integer, got %r' % (label_id,))
            return
        # A negative id would silently wrap around to another class.
        if not 0 <= label_id < num_classes:
            print('label_id must be in [0, %d), got %d' % (num_classes, label_id))
            return

        _print_predictions()
        text_label = _label_name(label_id)
        label_proba = np.round(predictions[0, label_id], 4)
        heatmap = vis_func(img, label_id)

        plt.figure(figsize=figsize)
        plt.subplot(1, 2, 1)
        plt.title('label:%s\nscore:%s' % (text_label, label_proba))
        plt.imshow(overlay(heatmap, img))
        plt.subplot(1, 2, 2)
        plt.imshow(img)
        plt.show()

    return interact(_plot, label_id='1')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号