def callback(itr):
    """Monitoring hook: plot prior samples and reconstructions, print test VLB.

    Args:
        itr: forwarded unchanged as the first argument of both
            ``viz.plot_samples`` calls (its exact use is up to ``viz``).

    Relies on names defined elsewhere in the file: ``np``, ``tf``, ``sess``,
    ``zdim``, ``encode``, ``decode``, ``X``, ``viz``, ``test_lb_fun`` and
    ``Ntest``.

    NOTE(review): every call creates fresh TF ops (``tf.random_normal`` here,
    and presumably whatever ``encode``/``decode`` build). Under TF1 graph mode
    this grows the default graph on each invocation; consider building the
    sampling ops once outside the callback.
    """

    def samplefun(num_samps):
        """Decode ``num_samps`` latent vectors drawn from N(0, I) in ``zdim`` dims."""
        z = np.random.randn(num_samps, zdim).astype(np.float32)
        return decode(z).eval(session=sess)

    viz.plot_samples(itr, samplefun, savedir='vae_mnist_samples')

    def sample_z(mu, log_sigmasq, M=5):
        """Return ``mu + exp(0.5 * log_sigmasq) * eps`` with eps ~ N(0, I) of shape (M, zdim)."""
        eps = tf.random_normal((M, zdim), dtype=tf.float32)
        # If log_sigmasq is a log-variance, exp(0.5 * log_sigmasq) is the std-dev.
        return mu + tf.exp(0.5 * log_sigmasq) * eps

    def recons(num_samps):
        """Stack one random row of ``X`` on top of 24 decoded samples from its encoding.

        NOTE(review): ``num_samps`` is ignored -- the result always has
        1 + 24 rows. Confirm this matches what ``viz.plot_samples`` expects.
        """
        # A single randomly chosen data row, kept 2-D via fancy indexing.
        subset = X[np.random.choice(X.shape[0], 1)]
        mu, log_sigmasq = encode(subset)
        imgs = decode(sample_z(mu, log_sigmasq, M=24)).eval(session=sess)
        # np.vstack: np.row_stack is a deprecated alias of it as of NumPy 2.0.
        return np.vstack([subset, imgs])

    viz.plot_samples(itr, recons, savedir='vae_mnist_samples', stub='recon')

    test_lb = test_lb_fun.eval(session=sess) * Ntest
    # Single parenthesised argument, so the output is identical under Python 2
    # (where `print "a ", x` emitted "a " + " " + str(x)) and Python 3.
    print("%s %s" % ("test data VLB: ", np.mean(test_lb)))
##########################################
# Make gradient descent fitting function #
##########################################
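# 评论列表 -- leftover web-page text ("comment list"), not part of the program
# 文章目录 -- leftover web-page text ("table of contents"), not part of the program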