vae-mnist.py source file

python

Project: hyperchamber  Author: 255BITS
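import numpy as np
import matplotlib.pyplot as plt
import hyperchamber as hc  # assumed alias, matching the hc.io.sample call below


def visualize(config, vae):
    # The grid plot only makes sense for a 2-dimensional latent space.
    if config['n_z'] != 2:
        print("Skipping visuals since n_z is not 2")
        return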
    nx = ny = 20
    x_values = np.linspace(-3, 3, nx)
    y_values = np.linspace(-3, 3, ny)

    canvas = np.empty((28*ny, 28*nx))
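    # Rows follow the second latent coordinate (yi), columns the first (xi).
    for i, yi in enumerate(y_values):
        for j, xi in enumerate(x_values):
            z_mu = np.array([[xi, yi]])
            # Repeat the latent point across the batch; only the first decoded image is used.
            x_mean = vae.generate(np.tile(z_mu, [config['batch_size'], 1]))
            canvas[(ny-i-1)*28:(ny-i)*28, j*28:(j+1)*28] = x_mean[0].reshape(28, 28)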

    plt.figure(figsize=(8, 10))
    plt.imshow(canvas, origin="upper")
    plt.tight_layout()

    img = "samples/2d-visualization.png"
    plt.savefig(img)
    hc.io.sample(config, [{"label": "2d visualization", "image": img}])
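
The tiling step in visualize can be tried without a trained model. The following is a minimal sketch, not part of the hyperchamber source: `fake_generate` and `build_canvas` are hypothetical names, and the stand-in decoder simply fills each 28x28 tile with a constant derived from the latent point, so the placement of every tile on the canvas can be checked by eye or by assertion.

```python
import numpy as np

TILE = 28  # MNIST images are 28x28 pixels


def fake_generate(z_batch):
    """Hypothetical stand-in for vae.generate: returns one flat 28x28
    image per latent row, filled with the constant z[0] + 10 * z[1]."""
    values = z_batch[:, 0] + 10.0 * z_batch[:, 1]
    return np.repeat(values[:, None], TILE * TILE, axis=1)


def build_canvas(generate, nx=4, ny=3, batch_size=2, lo=-3.0, hi=3.0):
    """Tile decoded images so that z[0] grows left to right and
    z[1] grows bottom to top."""
    x_values = np.linspace(lo, hi, nx)
    y_values = np.linspace(lo, hi, ny)
    canvas = np.empty((TILE * ny, TILE * nx))
    for i, yi in enumerate(y_values):
        for j, xi in enumerate(x_values):
            z_mu = np.array([[xi, yi]])
            # Repeat the latent point to fill a batch, as visualize does.
            x_mean = generate(np.tile(z_mu, [batch_size, 1]))
            row = (ny - i - 1) * TILE  # flip rows so larger z[1] sits higher
            col = j * TILE
            canvas[row:row + TILE, col:col + TILE] = x_mean[0].reshape(TILE, TILE)
    return canvas


canvas = build_canvas(fake_generate)
```

Using different values for nx and ny here makes any mix-up between the row and column axes show up as a shape or indexing error, which the square 20x20 grid in visualize would hide.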