solver.py source code

python

Project: pytorch-tutorial, Author: yunjey
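```python
# Module-level imports this method relies on in solver.py
import os

import torch
import torchvision


def sample(self):
    # Load the parameters saved at the final training epoch
    g_path = os.path.join(self.model_path, 'generator-%d.pkl' % self.num_epochs)
    d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % self.num_epochs)
    self.generator.load_state_dict(torch.load(g_path))
    self.discriminator.load_state_dict(torch.load(d_path))

    # Switch to evaluation mode (affects layers such as dropout and batch norm)
    self.generator.eval()
    self.discriminator.eval()

    # Generate a batch of images from random noise
    noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
    fake_images = self.generator(noise)

    # Save all samples as one image grid, 12 images per row;
    # self.denorm (defined elsewhere in this class) rescales the outputs for saving
    sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
    torchvision.utils.save_image(self.denorm(fake_images.data), sample_path, nrow=12)

    print("Saved sampled images to '%s'" % sample_path)
```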