def plot_heatmaps(img_arr, img_names, titles, heatmaps, labels, out_dir,
                  *, vmin=None, vmax=None):
    """Overlay each heatmap on its image and save one figure per sample.

    Each figure is saved to ``join(make_sub_dir(out_dir, CLASSES[y]), img_name)``,
    so outputs are grouped into one sub-directory per class label.

    Args:
        img_arr: Iterable of images. Each is transposed with ``(1, 2, 0)``
            before display, so presumably channel-first (C, H, W) — confirm
            against callers.
        img_names: File names for the saved figures; ``savefig`` picks the
            output format from the extension.
        titles: Per-figure titles.
        heatmaps: Per-image maps drawn over the image with a diverging
            palette at alpha 0.7.
        labels: Indices into ``CLASSES`` selecting each output sub-directory.
        out_dir: Root output directory passed to ``make_sub_dir``.
        vmin, vmax: Optional colour limits for the heatmap overlay. ``None``
            (the default) lets matplotlib autoscale each heatmap on its own,
            which is the previous behaviour. Pass shared values (e.g.
            ``np.min(heatmaps)`` / ``np.max(heatmaps)``) to get a colour scale
            that is comparable across figures.

    Note:
        ``zip`` stops at the shortest input, so inputs of mismatched length
        silently drop the trailing samples.
    """
    # Build a 30-step diverging colormap with a dark centre once, outside the loop.
    pal = sns.diverging_palette(240, 10, n=30, center="dark")
    my_cmap = ListedColormap(sns.color_palette(pal).as_hex())

    for img, img_name, h_map, title, y in zip(img_arr, img_names, heatmaps, titles, labels):
        fig = plt.figure()
        try:
            # Move channels last: imshow expects (H, W) or (H, W, 3/4).
            img = np.transpose(img, (1, 2, 0))
            # cmap='Greys' only takes effect for single-channel data;
            # matplotlib ignores cmap for RGB(A) images.
            plt.imshow(img, cmap='Greys', interpolation='bicubic')
            plt.imshow(h_map, cmap=my_cmap, alpha=0.7, interpolation='nearest',
                       vmin=vmin, vmax=vmax)
            plt.colorbar()
            plt.axis('off')
            plt.title(title)
            class_dir = make_sub_dir(out_dir, CLASSES[y])
            plt.savefig(join(class_dir, img_name), bbox_inches='tight', dpi=300)
        finally:
            # pyplot keeps every figure alive until it is closed; without this,
            # each iteration leaks one figure (and its pixel buffers).
            plt.close(fig)
# (Scraped web-page residue, not code) Comment list
# (Scraped web-page residue, not code) Table of contents