KalmanVariationalAutoencoder.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:kvae 作者: simonkamronn 项目源码 文件源码
def decoder(self, a_seq):
        """Build the decoder graph that maps latent codes to image reconstructions.

        ``a_seq`` is flattened to rows of size ``config.dim_a``. When ``config.conv``
        is set, the code is projected by a linear fully connected layer, reshaped to
        ``enc_shape`` and passed through one conv + ``subpixel_reshape(..., 2)`` stage
        per entry of ``num_filters`` (taken in reverse), followed by a 1x1 convolution
        to a single channel. Otherwise a stack of fully connected layers is used and
        the output is reshaped to ``(-1, d1, d2, 1)``.

        The mean uses a sigmoid activation when ``config.out_distr == 'bernoulli'`` and
        no activation otherwise.

        :param a_seq: latent code
        :return: tuple ``(x_hat, x_mu, x_var)`` where ``x_hat`` is reshaped to
                 ``(-1, ph_steps, d1, d2)``; for bernoulli outputs ``x_hat`` is ``x_mu``
                 itself, otherwise it is ``simple_sample(x_mu, x_var)``
        """
        cfg = self.config
        is_bernoulli = cfg.out_distr == 'bernoulli'
        mu_activation = tf.nn.sigmoid if is_bernoulli else None

        with tf.variable_scope('vae/decoder'):
            latent = tf.reshape(a_seq, (-1, cfg.dim_a))

            if cfg.conv:
                # Linear projection up to the encoder feature-map size, then back to a map.
                num_units = int(np.prod(self.enc_shape))
                hidden = slim.fully_connected(latent, num_units, activation_fn=None)
                hidden = tf.reshape(hidden, [-1] + self.enc_shape)

                # Each stage emits 4x the filters, which subpixel_reshape consumes
                # with a factor of 2.
                for n_filters in reversed(self.num_filters):
                    conv_out = slim.conv2d(hidden, n_filters * 4, cfg.filter_size,
                                           activation_fn=self.activation_fn)
                    hidden = subpixel_reshape(conv_out, 2)

                x_mu = slim.conv2d(hidden, 1, 1, stride=1, activation_fn=mu_activation)
            else:
                hidden = slim.repeat(latent, cfg.num_layers, slim.fully_connected,
                                     cfg.vae_num_units, self.activation_fn)

                flat_mu = slim.fully_connected(hidden, self.d1 * self.d2,
                                               activation_fn=mu_activation)
                x_mu = tf.reshape(flat_mu, (-1, self.d1, self.d2, 1))

            # x_var is not used for bernoulli outputs. The Gaussian output variance is
            # fixed here; it could also be learned globally for each pixel (as was done
            # in the pendulum experiment) or through a neural network.
            x_var = tf.constant(cfg.noise_pixel_var, dtype=tf.float32, shape=())

        # For bernoulli outputs the probabilities themselves are shown; otherwise sample.
        x_hat = x_mu if is_bernoulli else simple_sample(x_mu, x_var)

        out_shape = tf.stack((-1, self.ph_steps, self.d1, self.d2))
        return tf.reshape(x_hat, out_shape), x_mu, x_var
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号