def sample(self, filename, save_samples):
    """Scatter-plot real inputs against generated samples and pass the image on.

    One batch of ``x`` and ``z`` values is fetched from the session. The
    generator is then run with those exact values fed back in. Both point
    sets are drawn on a 3x3 inch figure with both axes limited to ``[-2, 2]``:
    inputs in blue, generated samples in red. The figure is rasterized to a
    ``uint8`` RGB array of shape ``(height, width, 3)``, which is handed to
    ``self.plot(data, filename, save_samples)``.

    Args:
        filename: Forwarded to ``self.plot`` and echoed back in the result.
        save_samples: Forwarded unchanged to ``self.plot``.

    Returns:
        A one-element list ``[{'image': filename, 'label': '2d'}]``.
    """
    gan = self.gan
    sess = gan.session

    # Feed the fetched x/z back in, so the generator output corresponds to
    # exactly the inputs being plotted.
    x_v, z_v = sess.run([gan.inputs.x, gan.encoder.z])
    sample = sess.run(gan.generator.sample,
                      {gan.inputs.x: x_v, gan.encoder.z: z_v})

    # NOTE(review): assumes both arrays are (N, 2) point sets -- TODO confirm.
    # Column indexing (unlike *zip(*points)) also copes with an empty batch.
    real = np.asarray(x_v)
    fake = np.asarray(sample)

    fig = plt.figure(figsize=(3, 3))
    try:
        ax = fig.add_subplot(111)
        ax.scatter(real[:, 0], real[:, 1], c='b')
        ax.scatter(fake[:, 0], fake[:, 1], c='r')
        ax.set_xlim([-2, 2])
        ax.set_ylim([-2, 2])
        ax.set_ylabel("z")
        fig.canvas.draw()
        # buffer_rgba() is already shaped (H, W, 4) in physical pixels.
        # Drop the alpha channel, and copy so the result does not alias the
        # renderer's buffer after the figure is closed.
        data = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, each call would leak one figure.
        plt.close(fig)

    self.plot(data, filename, save_samples)
    return [{'image': filename, 'label': '2d'}]
评论列表
文章目录