gmvae.py source code

Project: vi_vae_gmm    Author: wangg12
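import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec


def plot_images_and_clusters(images, clusters, epoch, save_path, ncol=10, n_clusters=10):
    '''Save a grid of images grouped by cluster label (0 .. n_clusters-1).'''
    fig = plt.figure()
    # Drop the trailing single-channel axis: (N, H, W, 1) -> (N, H, W)
    images = np.squeeze(images, -1)
    clusters = np.asarray(clusters)

    nrow = int(np.ceil(images.shape[0] / float(ncol)))
    gs = gridspec.GridSpec(nrow, ncol,
                           width_ratios=[1] * ncol, height_ratios=[1] * nrow)
    gs.update(wspace=0, hspace=0)  # tiles touch, no gaps

    n = 0  # next free grid cell; gs[n] fills the grid row by row
    for i in range(n_clusters):
        images_i = images[clusters == i]
        if images_i.shape[0] == 0:
            continue  # cluster i is empty

        # Draw all images of cluster i in consecutive cells
        for j in range(images_i.shape[0]):
            ax = fig.add_subplot(gs[n])
            n += 1
            ax.imshow(images_i[j], cmap='gray')
            ax.axis('off')
            ax.set_aspect('auto')
    fig.savefig(os.path.join(save_path, 'plot_gmvae_epoch_{}.png'.format(epoch)), dpi=fig.dpi)
    plt.close(fig)  # free the figure so per-epoch calls don't leak memory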
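The ordering that plot_images_and_clusters uses can be checked without drawing anything. The sketch below uses a hypothetical helper, cluster_grid_order (not part of vi_vae_gmm), and assumes cluster labels are integers in 0..n_clusters-1. It returns the image indices in the order they are placed into grid cells, plus the grid shape computed as ceil(N / ncol).

```python
import numpy as np


def cluster_grid_order(clusters, ncol=10, n_clusters=10):
    """Hypothetical helper: image indices in the order the plot draws them,
    plus the (nrow, ncol) grid shape.

    Images are grouped by cluster label 0..n_clusters-1; inside a cluster
    they keep their original order. Labels outside that range are skipped.
    """
    clusters = np.asarray(clusters)
    order = np.concatenate(
        [np.flatnonzero(clusters == i) for i in range(n_clusters)]
    ).astype(int)
    # Rows are sized from the total image count, as in the plotting code
    nrow = int(np.ceil(clusters.shape[0] / float(ncol)))
    return order, (nrow, ncol)


labels = np.array([2, 0, 2, 1, 0])
order, shape = cluster_grid_order(labels, ncol=3)
print(order.tolist(), shape)  # [1, 4, 3, 0, 2] (2, 3)
```

When every label lies in range, this order matches a stable sort by label, np.argsort(clusters, kind='stable'). Images with labels outside the range are skipped but still counted when sizing the grid, so they leave empty cells at the end.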