def plot_images_and_clusters(images, clusters, epoch, save_path, ncol=10):
    """Save a grid of images grouped by their cluster assignment.

    Images are laid out row-major on an ``nrow x ncol`` grid, ordered by
    ascending cluster label; within a cluster the input order is kept.
    The figure is written to ``<save_path>/plot_gmvae_epoch_<epoch>.png``
    and closed afterwards.

    Args:
        images: array of ``N`` images indexed along axis 0. A trailing
            singleton channel axis (e.g. ``(N, H, W, 1)``) is squeezed away
            so each image is shown as a 2-D grayscale array.
        clusters: one cluster label per image; any array-like with ``N``
            entries (it is flattened before use).
        epoch: value substituted into the output filename.
        save_path: directory the PNG file is written into.
        ncol: number of grid columns.
    """
    images = np.asarray(images)
    # Drop a trailing singleton channel axis ((N, H, W, 1) -> (N, H, W));
    # inputs that are already (N, H, W) are left as they are.
    if images.ndim == 4 and images.shape[-1] == 1:
        images = np.squeeze(images, -1)
    # Flatten so (N,) and (N, 1) label arrays (or plain lists) all yield a
    # 1-D boolean mask that matches axis 0 of `images`.
    clusters = np.asarray(clusters).ravel()

    n_images = images.shape[0]
    # GridSpec rejects zero rows, so keep at least one row for empty input.
    nrow = max(1, int(np.ceil(n_images / float(ncol))))

    fig = plt.figure()
    try:
        gs = gridspec.GridSpec(nrow, ncol,
                               width_ratios=[1] * ncol,
                               height_ratios=[1] * nrow)
        gs.update(wspace=0, hspace=0)

        n = 0
        # Iterate over the labels that actually occur (sorted ascending)
        # instead of a hard-coded range(10), so no image is silently dropped
        # and the grid sized for all images is filled completely.
        for label in np.unique(clusters):
            for image in images[clusters == label]:
                ax = fig.add_subplot(gs[n])
                n += 1
                ax.imshow(image, cmap='gray')
                ax.axis('off')
                ax.set_aspect('auto')

        fig.savefig(
            os.path.join(save_path, 'plot_gmvae_epoch_{}.png'.format(epoch)),
            dpi=fig.dpi)
    finally:
        # This runs once per epoch; without closing, pyplot keeps every
        # figure alive and memory grows across training.
        plt.close(fig)
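# 评论列表 ("Comment list") -- scraped web-page residue, not code
# 文章目录 ("Table of contents") -- scraped web-page residue, not code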