def generate(self, n_samples=100, batch_size=32, **kwargs):
    """
    Sample new data from the generator network.

    Latent codes are taken from ``latent_samples`` when given; otherwise they
    are drawn here and then pushed through ``self.generative_model``.

    Args:
        n_samples: int, the number of samples to be generated. Ignored when
            ``latent_samples`` is supplied. When ``self.latent_dim == 2`` the
            latent codes are laid out on a square grid with
            ``int(sqrt(n_samples))`` points per axis, so the number of
            generated rows is ``int(sqrt(n_samples)) ** 2``, which is smaller
            than ``n_samples`` unless ``n_samples`` is a perfect square.
        batch_size: int, number of latent samples handed to the generator
            at once

    Keyword Args:
        return_probs: bool, whether the output generations should be raw
            probabilities (True, the default) or sampled Bernoulli outcomes
            (False)
        latent_samples: ndarray, alternative source of latent encoding,
            otherwise sampling will be applied

    Returns:
        The output of ``self.generative_model.predict_generator`` when
        ``return_probs`` is true, otherwise an array of 0/1 draws with the
        same shape, sampled with those outputs as success probabilities.
    """
    return_probs = kwargs.get('return_probs', True)
    latent_samples = kwargs.get('latent_samples', None)

    if latent_samples is None:
        if self.latent_dim == 2:
            # Deterministic 2d grid instead of random draws. np.mgrid reads a
            # complex step as a point count (its magnitude), end points included.
            n_samples_per_axis = complex(int(np.sqrt(n_samples)))
            uniform_grid = np.mgrid[0.01:0.99:n_samples_per_axis,
                                    0.01:0.99:n_samples_per_axis].reshape(2, -1).T
            # Map the uniform grid in (0, 1) through `standard_gaussian.ppf`
            # (presumably the inverse CDF of a standard normal, e.g.
            # scipy.stats.norm -- confirm against the file's imports).
            latent_samples = standard_gaussian.ppf(uniform_grid)
        else:
            latent_samples = np.random.standard_normal(size=(n_samples, self.latent_dim))

    # Single decoding path shared by user-supplied and freshly drawn latents.
    data_iterator, n_iters = self.data_iterator.iter(latent_samples, batch_size=batch_size, mode='generation')
    data_probs = self.generative_model.predict_generator(data_iterator, steps=n_iters)

    if return_probs:
        return data_probs
    # One Bernoulli trial per output entry, using the entry as its probability.
    return np.random.binomial(1, p=data_probs)
评论列表
文章目录