base_vae.py source code

python

Project: adversarial-variational-bayes Author: gdikov
# Module-level imports this excerpt relies on. They are not shown in the snippet;
# `standard_gaussian` is assumed to be scipy.stats.norm, since only its `ppf` is used.
import numpy as np
from scipy.stats import norm as standard_gaussian

def generate(self, n_samples=100, batch_size=32, **kwargs):
        """
        Sample new data from the generator network.

        Args:
            n_samples: int, the number of samples to be generated
            batch_size: int, the number of samples generated per batch

        Keyword Args:
            return_probs: bool, whether the output generations should be raw probabilities or sampled Bernoulli outcomes
            latent_samples: ndarray, alternative source of latent encoding, otherwise sampling will be applied

        Returns:
            The generated data as ndarray of shape (n_samples, data_dim)
        """
        return_probs = kwargs.get('return_probs', True)
        latent_samples = kwargs.get('latent_samples', None)

        if latent_samples is None:
            if self.latent_dim == 2:
                # Build a regular 2D grid of quantiles and map it through the standard
                # Gaussian inverse CDF. np.mgrid reads a complex step as a point count,
                # so this yields int(sqrt(n_samples)) ** 2 latent points.
                n_samples_per_axis = complex(int(np.sqrt(n_samples)))
                uniform_grid = np.mgrid[0.01:0.99:n_samples_per_axis, 0.01:0.99:n_samples_per_axis].reshape(2, -1).T
                latent_samples = standard_gaussian.ppf(uniform_grid)
            else:
                latent_samples = np.random.standard_normal(size=(n_samples, self.latent_dim))
        # Feed the latent codes (supplied or sampled) through the generator in batches.
        data_iterator, n_iters = self.data_iterator.iter(latent_samples, batch_size=batch_size, mode='generation')
        data_probs = self.generative_model.predict_generator(data_iterator, steps=n_iters)
        if return_probs:
            return data_probs
        sampled_data = np.random.binomial(1, p=data_probs)
        return sampled_data
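
The two numerical steps in `generate` can be tried in isolation, without the model. The sketch below is only an illustration under stated assumptions: `gaussian_latent_grid` and `sample_bernoulli` are hypothetical helper names that do not appear in the project, and the standard library's `NormalDist().inv_cdf` stands in for `standard_gaussian.ppf` so that the example needs nothing beyond NumPy.

```python
from statistics import NormalDist

import numpy as np

# Stand-in for `standard_gaussian.ppf` (assumed to be scipy.stats.norm.ppf in the source).
_inverse_cdf = np.vectorize(NormalDist().inv_cdf)


def gaussian_latent_grid(n_samples):
    """Return a (k * k, 2) array of latent points, with k = int(sqrt(n_samples))."""
    k = int(np.sqrt(n_samples))
    # A complex step makes np.mgrid emit k evenly spaced points, endpoints included.
    quantiles = np.mgrid[0.01:0.99:k * 1j, 0.01:0.99:k * 1j].reshape(2, -1).T
    # The inverse CDF turns uniform quantiles into standard normal coordinates.
    return _inverse_cdf(quantiles)


def sample_bernoulli(probs, seed=0):
    """Draw one 0/1 outcome per entry of `probs`, as in the return_probs=False path."""
    rng = np.random.default_rng(seed)
    return rng.binomial(1, p=probs)


latents = gaussian_latent_grid(100)
print(latents.shape)  # (100, 2)
```

When `n_samples` is not a perfect square the grid is smaller than requested: `gaussian_latent_grid(50)` returns 49 points, which mirrors what the two-dimensional branch of `generate` does.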