def save_ims(filename, ims, dpi=100, scale=0.5):
    """Save a batch of channels-first images as one tiled grid figure.

    The grid has ``ceil(sqrt(n))`` rows and ``round(sqrt(n))`` columns, which
    always provides at least ``n`` cells. Unused trailing cells are left blank.

    Args:
        filename: Destination passed straight to ``Figure.savefig``.
        ims: Array of shape ``(n, c, h, w)``. Each image is transposed to
            ``(h, w, c)`` for ``imshow``. Single-channel images (``c == 1``)
            are drawn as 2-D grayscale images, because ``imshow`` rejects an
            ``(h, w, 1)`` array.
        dpi: Resolution used both to size the figure and to save it.
        scale: Multiplier on each image's native pixel size when computing
            the figure size.

    Raises:
        ValueError: If ``ims`` contains no images (``n == 0``).
    """
    n, c, h, w = ims.shape
    if n == 0:
        raise ValueError("save_ims: 'ims' must contain at least one image")

    rows = int(math.ceil(math.sqrt(n)))
    cols = int(round(math.sqrt(n)))

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid (n == 1).
    # Without it, plt.subplots returns a bare Axes that has no `.flat`.
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(w * cols / dpi * scale, h * rows / dpi * scale),
        dpi=dpi,
        squeeze=False,
    )
    try:
        for i, ax in enumerate(axes.flat):
            if i < n:
                img = ims[i].transpose((1, 2, 0))
                if c == 1:
                    # imshow cannot draw (h, w, 1); drop the channel axis.
                    ax.imshow(img[:, :, 0], cmap='gray')
                else:
                    ax.imshow(img)
            # Hide ticks and frames on every cell, including unused ones.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis('off')
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.1, hspace=0.1)
        # bbox_inches='tight' trims surrounding whitespace from the saved image.
        fig.savefig(filename, dpi=dpi, bbox_inches='tight', transparent=True)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
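# 评论列表 ("Comment list") / 文章目录 ("Table of contents"): web-page
# navigation labels left over from scraping; not part of the code.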