def plot_tsne(z_mu, classes, name, out_dir='./vae_results'):
    """Embed latent vectors in 2-D with t-SNE and save per-class scatter plots.

    Classes are drawn one at a time onto a single figure. The figure is saved
    after each class is added, so ``<name>_embedding_<k>.png`` shows classes
    ``0..k`` cumulatively. The complete figure is then saved as
    ``<name>_embedding.png``.

    Args:
        z_mu: Tensor of latent vectors, one row per sample. Assumed to be a
            torch tensor (it is detached, moved to CPU and converted to NumPy).
        classes: Tensor with one column per class. A row belongs to class
            ``k`` when ``classes[:, k] == 1`` (presumably one-hot labels,
            same row order as ``z_mu``; confirm against callers).
        name: Prefix for the output file names (converted with ``str``).
        out_dir: Directory the PNG files are written to. It is created if
            missing. Defaults to ``./vae_results``, the original hard-coded
            location.
    """
    import os

    import matplotlib
    # Non-interactive backend so plotting works on headless machines.
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE

    # detach() drops autograd history before converting to NumPy.
    z_states = z_mu.detach().cpu().numpy()
    labels = classes.detach().cpu().numpy()
    z_embed = TSNE(n_components=2, random_state=0).fit_transform(z_states)

    os.makedirs(out_dir, exist_ok=True)
    prefix = str(name)

    n_classes = labels.shape[1]
    # Set1 has only 9 colours, so a 10th class reused the 9th class's colour.
    # tab10 / tab20 give distinct colours for up to 10 / 20 classes.
    cmap = plt.cm.tab10 if n_classes <= 10 else plt.cm.tab20

    fig = plt.figure()
    try:
        plt.title("Latent Variable T-SNE per Class")
        for ic in range(n_classes):
            mask = labels[:, ic] == 1
            plt.scatter(z_embed[mask, 0], z_embed[mask, 1], s=10,
                        color=cmap(ic % cmap.N))
            # Snapshot after every class: file k shows classes 0..k.
            fig.savefig(os.path.join(
                out_dir, prefix + '_embedding_' + str(ic) + '.png'))
        fig.savefig(os.path.join(out_dir, prefix + '_embedding.png'))
    finally:
        # Release the figure; pyplot otherwise keeps it alive across calls.
        plt.close(fig)
评论列表
文章目录