def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save everything to disk.

    Writes three kinds of output into ``logs_dir``:
      * ``image.pk`` -- a pickle of the whole ``images`` object;
      * one JPEG per filled grid cell, named by its index (``"0"``, ``"1"``,
        ... -- note: no file extension);
      * the grid figure itself, saved as ``filename``.

    Args:
        images: Indexable sequence of images accepted by both
            ``Axes.imshow`` and ``PIL.Image.fromarray`` (presumably uint8
            arrays of shape (H, W) or (H, W, C) -- TODO confirm with callers).
            If it holds fewer images than grid cells, the extra cells are
            left blank; if it holds more, only the first
            ``shape[0] * shape[1]`` are plotted and saved.
        logs_dir: Existing directory that receives all output files.
        filename: Name of the grid figure file, relative to ``logs_dir``;
            its extension selects the matplotlib output format.
        shape: ``(nrows, ncols)`` of the grid.
    """
    # Close the pickle file deterministically instead of leaking the handle.
    with open(os.path.join(logs_dir, "image.pk"), "wb") as fh:
        pickle.dump(images, fh)

    # Use a fresh figure rather than the shared figure number 1: reusing it
    # would stack this grid on top of whatever an earlier call drew.
    fig = plt.figure()
    try:
        grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
        size = shape[0] * shape[1]
        # Fewer images than cells used to raise IndexError; the surplus
        # cells are now simply left blank.
        count = min(size, len(images))
        for i in trange(size, desc="Saving images"):
            grid[i].axis('off')
            if i >= count:
                continue
            grid[i].imshow(images[i])
            pil_image = Image.fromarray(images[i])
            # JPEG supports neither alpha channels nor palettes; saving such
            # modes directly fails, so flatten them to RGB first.
            if pil_image.mode in ("RGBA", "LA", "P"):
                pil_image = pil_image.convert("RGB")
            pil_image.save(os.path.join(logs_dir, str(i)), "jpeg")
        # Save this specific figure, not whatever pyplot considers "current".
        fig.savefig(os.path.join(logs_dir, filename))
    finally:
        # pyplot keeps every open figure alive until it is closed explicitly.
        plt.close(fig)
评论列表
文章目录