def visualize(config, vae, nx=20, ny=20, z_range=3.0, img_size=28):
    """Render a grid of decoder outputs over a 2-D latent space and save it.

    Only runs when the latent space is 2-D (``config['n_z'] == 2``); otherwise
    prints a notice and returns. Latent points are laid out on an ``ny`` x
    ``nx`` grid spanning ``[-z_range, z_range]`` on both axes. Each point is
    passed to ``vae.generate`` and row 0 of the result is reshaped to
    ``img_size`` x ``img_size`` and pasted into a single canvas, with z[0]
    increasing left-to-right and z[1] increasing bottom-to-top. The canvas is
    saved to ``samples/2d-visualization.png`` and reported via
    ``hc.io.sample``.

    Args:
        config: mapping read for ``'n_z'`` and (when n_z == 2) ``'batch_size'``;
            it is also forwarded to ``hc.io.sample``.
        vae: object with a ``generate(z)`` method, called with an array of
            shape ``(config['batch_size'], 2)``; row 0 of its output must be
            reshapeable to ``(img_size, img_size)``.
        nx: number of grid columns (samples along z[0]). Default 20.
        ny: number of grid rows (samples along z[1]). Default 20.
        z_range: half-width of the latent interval sampled on each axis.
        img_size: side length in pixels of each decoded image. Default 28.

    Returns:
        None.
    """
    import os

    if config['n_z'] != 2:
        print("Skipping visuals since n_z is not 2")
        return

    x_values = np.linspace(-z_range, z_range, nx)
    y_values = np.linspace(-z_range, z_range, ny)
    canvas = np.empty((img_size * ny, img_size * nx))
    for i, yi in enumerate(y_values):
        # Canvas row 0 is the top of the image, so flip i to draw larger
        # z[1] values higher up.
        top = (ny - i - 1) * img_size
        for j, xi in enumerate(x_values):
            z_mu = np.array([[xi, yi]])
            # NOTE(review): the single point is tiled to a full batch and only
            # row 0 of the output is used -- presumably generate() requires a
            # fixed batch size; confirm before batching distinct points.
            x_mean = vae.generate(np.tile(z_mu, [config['batch_size'], 1]))
            left = j * img_size
            canvas[top:top + img_size, left:left + img_size] = (
                x_mean[0].reshape(img_size, img_size))

    img = "samples/2d-visualization.png"
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(img), exist_ok=True)
    fig = plt.figure(figsize=(8, 10))
    try:
        plt.imshow(canvas, origin="upper")
        plt.tight_layout()
        plt.savefig(img)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    hc.io.sample(config, [{"label": "2d visualization", "image": img}])
评论列表
文章目录