def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save the figure to disk.

    Args:
        images: Indexable, sized sequence of images. Each element is passed
            directly to ``Axes.imshow`` (presumably 2-D or HxWxC arrays --
            confirm against callers).
        logs_dir: Directory the figure file is written into.
        filename: Name of the output file inside ``logs_dir``. Its extension
            selects the format used by ``Figure.savefig``.
        shape: ``(nrows, ncols)`` layout of the grid.

    Every grid cell has its axis decorations turned off. If fewer than
    ``nrows * ncols`` images are supplied, the remaining cells are left
    blank instead of raising ``IndexError``. Images beyond
    ``nrows * ncols`` are ignored, as before.
    """
    # Use a fresh, unnumbered figure. ``plt.figure(1)`` returned the same
    # figure on every call, so each new ImageGrid was drawn over the last.
    fig = plt.figure()
    try:
        grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
        size = shape[0] * shape[1]
        n_images = min(size, len(images))

        # Hide ticks/frames on every cell, including cells left empty.
        for i in range(size):
            grid[i].axis('off')

        for i in trange(n_images, desc="Saving images"):
            grid[i].imshow(images[i])

        # Save this specific figure, not whatever pyplot considers "current".
        fig.savefig(os.path.join(logs_dir, filename))
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
评论列表
文章目录