def generate(estimator, n=15, digit_size=28):
    """Decode an ``n`` x ``n`` grid of 2-D latent points and show the tiled images.

    The latent grid applies ``norm.ppf`` to ``n`` evenly spaced probabilities
    in ``[0.05, 0.95]`` on each axis. The points are therefore evenly spaced
    in probability mass of a standard normal, not evenly spaced in latent
    value.

    Args:
        estimator: Object whose ``generate`` method takes an input function and
            returns an iterable of dicts with a ``'results'`` entry per sample
            (presumably a plx estimator -- confirm against the caller).
        n: Number of grid points per latent axis; the figure holds ``n * n``
            tiles. Defaults to 15.
        digit_size: Side length, in pixels, of each decoded image; every
            ``'results'`` entry must hold ``digit_size ** 2`` values.
            Defaults to 28.

    Returns:
        numpy.ndarray: The ``(digit_size * n, digit_size * n)`` tiled figure.
        It is returned whether or not matplotlib is available for display.

    Raises:
        IndexError: If the estimator yields fewer than ``n * n`` results.
    """
    import warnings

    from scipy.stats import norm

    # Both latent axes share the same quantile grid.
    grid = norm.ppf(np.linspace(0.05, 0.95, n))
    # Row-major latent points: sample k = i * n + j is (grid[i], grid[j]),
    # matching the tiling order below; float32 as fed to the network.
    grid_x, grid_y = np.meshgrid(grid, grid, indexing='ij')
    samples = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)

    predictions = estimator.generate(
        plx.processing.numpy_input_fn({'samples': samples}, batch_size=n * n, shuffle=False))
    results = [prediction['results'] for prediction in predictions]

    figure = np.zeros((digit_size * n, digit_size * n))
    for i in range(n):
        for j in range(n):
            top, left = i * digit_size, j * digit_size
            figure[top:top + digit_size, left:left + digit_size] = np.reshape(
                results[i * n + j], (digit_size, digit_size))

    # matplotlib is optional. Only the import is guarded, so real errors
    # raised while plotting are no longer silently swallowed.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        warnings.warn(
            'matplotlib is not installed; skipping the plot of generated samples.',
            stacklevel=2)
        return figure
    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='Greys_r')
    plt.show()
    return figure
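# Source file: variational_autoencoder_mnist.py (Python)
# (Residue from the code-sharing web page this snippet was copied from:
#  reads 17, favorites 0, likes 0, comments 0, comment list, table of contents.)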