variational_autoencoder_mnist.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:polyaxon 作者: polyaxon 项目源码 文件源码
def generate(estimator, n=15, digit_size=28):
    """Decode a regular grid of 2-D latent points and plot the resulting digits.

    Each latent axis is sampled at ``n`` evenly spaced quantiles (5% to 95%)
    of a standard normal distribution, so grid points are spread evenly in
    probability mass rather than in raw latent coordinates. The decoded
    images are tiled into one ``n x n`` mosaic, which is shown with
    matplotlib when it is installed.

    Args:
        estimator: object whose ``generate(input_fn)`` yields one dict with a
            ``'results'`` array per input sample.
            NOTE(review): assumes the model takes a 2-D latent vector under
            the ``'samples'`` key — confirm against the model definition.
        n: number of grid points per latent axis; the mosaic holds ``n * n``
            decoded images.
        digit_size: side length in pixels of each decoded image; every
            ``'results'`` entry must be reshapeable to
            ``(digit_size, digit_size)``.

    Returns:
        The assembled ``(n * digit_size, n * digit_size)`` image mosaic, so
        callers can still use it when matplotlib is unavailable.

    Raises:
        ValueError: if the estimator yields fewer than ``n * n`` results.
    """
    import itertools

    from scipy.stats import norm

    # Deterministic (not random) grid: evenly spaced probabilities mapped
    # through the inverse standard-normal CDF. Both axes use the same values.
    axis = norm.ppf(np.linspace(0.05, 0.95, n))
    # Row-major order: sample k = i * n + j is (axis[i], axis[j]); the mosaic
    # assembly below relies on this ordering.
    samples = np.array([[x, y] for x in axis for y in axis], dtype=np.float32)

    num_samples = n * n
    predictions = estimator.generate(
        plx.processing.numpy_input_fn(
            {'samples': samples}, batch_size=num_samples, shuffle=False))
    # Only the first n * n predictions are placed in the mosaic; islice stops
    # pulling from the prediction iterator once they have been collected.
    results = [p['results'] for p in itertools.islice(predictions, num_samples)]
    if len(results) < num_samples:
        raise ValueError(
            'expected %d generated images, got %d' % (num_samples, len(results)))

    figure = np.zeros((digit_size * n, digit_size * n))
    for i in range(n):
        for j in range(n):
            digit = results[i * n + j].reshape(digit_size, digit_size)
            figure[i * digit_size:(i + 1) * digit_size,
                   j * digit_size:(j + 1) * digit_size] = digit

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # Plotting is optional: without matplotlib, just hand back the mosaic.
        return figure

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show()
    return figure
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号