def sample(self):
    """Generate a grid of fake images from the final-epoch checkpoint.

    Restores ``generator-<num_epochs>.pkl`` and ``discriminator-<num_epochs>.pkl``
    from ``self.model_path`` and switches both networks to eval mode. It then
    draws a ``(self.sample_size, self.z_dim)`` standard-normal noise batch and
    runs it through the generator. The output is saved as one image grid with
    12 images per row to ``<self.sample_path>/fake_samples-final.png``.

    Checkpoints are deserialized onto the CPU (``map_location``), so weights
    saved from a GPU run can also be loaded on a CPU-only machine.
    ``load_state_dict`` then copies those tensors into the modules' existing
    parameters, wherever those parameters live.
    """
    # Load trained parameters from the last training epoch.
    g_path = os.path.join(self.model_path, 'generator-%d.pkl' % (self.num_epochs))
    d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % (self.num_epochs))

    # Keep every storage on the CPU while unpickling. Without this, a
    # checkpoint written from CUDA tensors fails to load when CUDA is absent.
    cpu_storage = lambda storage, loc: storage
    self.generator.load_state_dict(torch.load(g_path, map_location=cpu_storage))
    # NOTE(review): the discriminator is not used for sampling below. It is
    # still restored so both networks reflect the same checkpoint; dropping
    # this would skip an unnecessary file read.
    self.discriminator.load_state_dict(torch.load(d_path, map_location=cpu_storage))
    self.generator.eval()
    self.discriminator.eval()

    # Sample the images.
    # to_variable is defined elsewhere; presumably it wraps the tensor and
    # moves it to the model's device -- confirm against its definition.
    # NOTE(review): no gradients are needed here; on PyTorch >= 0.4 this
    # forward pass could run under torch.no_grad() to avoid building a graph.
    noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
    fake_images = self.generator(noise)

    # denorm is defined elsewhere; presumably it maps the generator's output
    # range back to [0, 1] for saving -- confirm.
    sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
    torchvision.utils.save_image(self.denorm(fake_images.data), sample_path, nrow=12)
    print("Saved sampled images to '%s'" % sample_path)
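# NOTE: the two labels below are web-page navigation residue captured along
# with this snippet, not code: "评论列表" ("Comment list"), "文章目录" ("Table of contents").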