def plot(self, vis_func, img_path, label_list, figsize):
    """Return an interactive per-class heatmap viewer for one image.

    The image at ``img_path`` is loaded at ``self.img_shape_``, cut down to
    its first three channels, and scored by ``self.model_``; the scores are
    then passed through ``softmax``. The returned ``interact`` control takes
    a ``label_id`` (default ``'1'``). For that class it prints a score
    summary and draws a two-panel figure: ``overlay(heatmap, img)`` on the
    left, titled with the label name and the score rounded to 4 decimals,
    and the plain image on the right.

    Args:
        vis_func: callable invoked as ``vis_func(img, label_id)``; its result
            is handed to ``overlay`` as the heatmap.
        img_path: path passed to ``utils.load_img``.
        label_list: class names indexed by label id, also used as the
            column names of the printed score table. When ``None`` or
            empty, names come from ``get_pred_text_label`` and the entries
            of ``decode_predictions(predictions)[0]`` are printed instead.
        figsize: forwarded to ``plt.figure``.

    Returns:
        Whatever ``interact(_plot, label_id='1')`` returns (presumably an
        ipywidgets interactive widget -- confirm against the import).
    """
    img = utils.load_img(img_path, target_size=self.img_shape_)
    # Keep only the first three entries of the last axis (drops a possible
    # alpha channel); assumes an (H, W, C) array -- TODO confirm.
    img = img[:, :, :3]
    # NOTE(review): if the model's last layer is already a softmax, this
    # normalises a second time and flattens the scores -- confirm.
    predictions = softmax(self.model_.predict(img2tensor(img, self.img_shape_)))

    # `not label_list` raises ValueError for a multi-element numpy array or
    # pandas Index, so test for None / emptiness explicitly.
    has_labels = label_list is not None and len(label_list) > 0

    if has_labels:
        def _label_text(label_id):
            return label_list[label_id]

        def _print_scores():
            print(pd.DataFrame(predictions, columns=label_list))
    else:
        # Decode once here rather than on every widget update.
        prediction_text = decode_predictions(predictions)[0]

        def _label_text(label_id):
            return get_pred_text_label(label_id)

        def _print_scores():
            # Print each decoded entry without its first field.
            for entry in prediction_text:
                print(entry[1:])

    def _plot(label_id):
        # label_id may arrive as a string (the widget default is '1').
        label_id = int(label_id)
        _print_scores()
        text_label = _label_text(label_id)
        label_proba = np.round(predictions[0, label_id], 4)
        heatmap = vis_func(img, label_id)

        plt.figure(figsize=figsize)
        plt.subplot(1, 2, 1)
        plt.title('label:%s\nscore:%s' % (text_label, label_proba))
        plt.imshow(overlay(heatmap, img))
        plt.subplot(1, 2, 2)
        plt.imshow(img)
        plt.show()

    return interact(_plot, label_id='1')
class_heatmap.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录