def decoder(self, a_seq):
""" Convolutional variational decoder to decode latent code to image reconstruction
If config.conv == False it is a MLP VAE. If config.use_vae == False it is a normal decoder
:param a_seq: latent code
:return: x_hat, x_mu, x_var
"""
# Create decoder
if self.config.out_distr == 'bernoulli':
activation_x_mu = tf.nn.sigmoid
else:
activation_x_mu = None
with tf.variable_scope('vae/decoder'):
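        # Flatten the sequence: (batch_size, seq_length, dim_a) -> (batch_size * seq_length, dim_a)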
a = tf.reshape(a_seq, (-1, self.config.dim_a))
if self.config.conv:
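            # Project the latent code up to the flattened shape of the last encoder feature map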
dec_upscale = slim.fully_connected(a, int(np.prod(self.enc_shape)), activation_fn=None)
dec_upscale = tf.reshape(dec_upscale, [-1] + self.enc_shape)
dec_hidden = dec_upscale
for filters in reversed(self.num_filters):
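                # Each conv quadruples the channels; subpixel_reshape then trades them for a
                # 2x spatial upscale (pixel shuffle)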
dec_hidden = slim.conv2d(dec_hidden, filters * 4, self.config.filter_size,
activation_fn=self.activation_fn)
dec_hidden = subpixel_reshape(dec_hidden, 2)
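            # A 1x1 conv maps the upscaled features to a single-channel image of pixel means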
x_mu = slim.conv2d(dec_hidden, 1, 1, stride=1, activation_fn=activation_x_mu)
x_var = tf.constant(self.config.noise_pixel_var, dtype=tf.float32, shape=())
else:
            dec_hidden = slim.repeat(a, self.config.num_layers, slim.fully_connected,
                                     self.config.vae_num_units, activation_fn=self.activation_fn)
x_mu = slim.fully_connected(dec_hidden, self.d1 * self.d2, activation_fn=activation_x_mu)
x_mu = tf.reshape(x_mu, (-1, self.d1, self.d2, 1))
            # x_var is not used for Bernoulli outputs. Here we fix the output variance of the
            # Gaussian; we could also learn it globally for each pixel (as we did in the pendulum
            # experiment) or predict it with a neural network.
x_var = tf.constant(self.config.noise_pixel_var, dtype=tf.float32, shape=())
if self.config.out_distr == 'bernoulli':
        # For Bernoulli outputs the mean already gives the pixel probabilities
x_hat = x_mu
else:
x_hat = simple_sample(x_mu, x_var)
return tf.reshape(x_hat, tf.stack((-1, self.ph_steps, self.d1, self.d2))), x_mu, x_var
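
# subpixel_reshape is used above but not defined in this snippet. A minimal sketch of
# what it plausibly does, assuming NHWC tensors whose channel count is divisible by
# factor**2 (an assumption, not the repo's actual implementation):
def subpixel_reshape(x, factor):
    """Pixel-shuffle upscaling: move channel blocks into space, growing H and W by `factor`."""
    # tf.depth_to_space performs this channels-to-space rearrangement for NHWC inputs
    return tf.depth_to_space(x, factor)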
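
# simple_sample is also external to this snippet. A sketch under the assumption that it
# draws one sample from a diagonal Gaussian N(mu, var) via the reparameterization trick:
def simple_sample(mu, var):
    """Sample mu + sqrt(var) * eps with eps ~ N(0, I)."""
    eps = tf.random_normal(tf.shape(mu), dtype=mu.dtype)
    return mu + tf.sqrt(var) * eps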