common.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:HyperGAN 作者: 255BITS 项目源码 文件源码
def sample(self, filename, save_samples):
        gan = self.gan
        generator = gan.generator.sample

        sess = gan.session
        config = gan.config
        x_v, z_v = sess.run([gan.inputs.x, gan.encoder.z])

        sample = sess.run(generator, {gan.inputs.x: x_v, gan.encoder.z: z_v})

        plt.clf()
        fig = plt.figure(figsize=(3,3))
        plt.scatter(*zip(*x_v), c='b')
        plt.scatter(*zip(*sample), c='r')
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.ylabel("z")
        fig.canvas.draw()
        data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        #plt.savefig(filename)
        self.plot(data, filename, save_samples)
        return [{'image': filename, 'label': '2d'}]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号