vae_plots.py source code

python

Project: pyro    Author: uber
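def plot_tsne(z_mu, classes, name):
    import os
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: render straight to image files
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE

    # Project the latent means (N x latent_dim) down to 2-D with t-SNE.
    model_tsne = TSNE(n_components=2, random_state=0)
    z_states = z_mu.detach().cpu().numpy()
    z_embed = model_tsne.fit_transform(z_states)
    # classes is a one-hot label matrix of shape (N, num_classes).
    classes = classes.detach().cpu().numpy()

    out_dir = './vae_results'
    os.makedirs(out_dir, exist_ok=True)  # savefig fails if the folder is missing

    fig = plt.figure()
    plt.title("Latent Variable t-SNE per Class")
    for ic in range(classes.shape[1]):
        # Boolean mask of the rows whose one-hot label is class ic.
        ind_class = classes[:, ic] == 1
        # tab10 provides 10 distinct colors, one per digit class.
        color = plt.cm.tab10(ic)
        plt.scatter(z_embed[ind_class, 0], z_embed[ind_class, 1], s=10, color=color)
        # Save a cumulative snapshot after each class has been added.
        fig.savefig(os.path.join(out_dir, str(name) + '_embedding_' + str(ic) + '.png'))
    fig.savefig(os.path.join(out_dir, str(name) + '_embedding.png'))
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures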
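
The per-class loop assumes `classes` is a one-hot matrix of shape (N, num_classes), so `classes[:, ic] == 1` is a boolean mask selecting the rows labelled `ic`. The NumPy-only sketch below illustrates that selection on made-up toy data; the helper `points_for_class` and the arrays are hypothetical and not part of the pyro source.

```python
import numpy as np

# Hypothetical toy data: 6 points already embedded in 2-D, with one-hot
# labels over 3 classes (same layout plot_tsne works with after .numpy()).
z_embed = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0],
                    [3.0, 3.0], [4.0, 4.0], [5.0, 5.0]])
labels = np.array([0, 2, 1, 0, 2, 2])
classes = np.eye(3)[labels]  # one-hot: row i has a 1 in column labels[i]


def points_for_class(z_embed, classes, ic):
    """Return the embedded points whose one-hot label is class ic (hypothetical helper)."""
    mask = classes[:, ic] == 1  # boolean mask of length N
    return z_embed[mask]


for ic in range(classes.shape[1]):
    print(ic, points_for_class(z_embed, classes, ic).tolist())

# argmax along the class axis recovers the integer labels from the one-hot matrix.
recovered = classes.argmax(axis=1)
print(recovered.tolist())
```

With integer labels such as `recovered`, a single `plt.scatter(z_embed[:, 0], z_embed[:, 1], c=recovered, cmap='tab10')` call could replace the loop, at the cost of losing the per-class snapshot images.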