def visualize(gen, epoch, savedir, batch_size=36, image_type='sigmoid'):
    """Sample images from a generator and save them as one grid image.

    Draws ``batch_size`` latent vectors via ``gen.make_hidden``, runs the
    generator with ``train=False``, converts the output to 8-bit pixels and
    writes every sample into a near-square grid saved as
    ``'{savedir}/generate_{epoch:03d}'``. No extension is given, so matplotlib
    chooses the file format from its default settings.

    Args:
        gen: Generator object. It must expose ``xp``, ``make_hidden(n)`` and be
            callable as ``gen(z, train=False)`` returning a Variable.
        epoch: Integer embedded (zero-padded to 3 digits) in the file name.
        savedir: Directory the figure is written into.
        batch_size: Number of samples to generate and plot. A batch of 36
            produces a 6x6 grid.
        image_type: ``'sigmoid'`` scales the output by 255 (presumably values
            in [0, 1]). Any other value maps the output from [-1, 1]
            (tanh-style) to [0, 255].
    """
    import math

    # NOTE: `volatile=True` / `train=False` are the Chainer v1 inference API
    # this file is written against.
    z = chainer.Variable(gen.xp.asarray(gen.make_hidden(batch_size)),
                         volatile=True)
    x_fake = gen(z, train=False)

    x = cuda.to_cpu(x_fake.data)
    if image_type == 'sigmoid':
        x = x * 255
    else:
        x = (x + 1) * 127.5
    img_gen = x.clip(0, 255).astype(np.uint8)

    # Derive the grid from the actual sample count instead of assuming 36, so
    # smaller batches do not raise IndexError and larger ones are not
    # silently truncated. 36 samples still give the original 6x6 layout.
    n_images = len(img_gen)
    n_cols = max(1, int(math.ceil(math.sqrt(n_images))))
    n_rows = int(math.ceil(n_images / float(n_cols)))

    fig = plt.figure(figsize=(9, 9))
    try:
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1,
                            hspace=0.05, wspace=0.05)
        for i, img in enumerate(img_gen):
            ax = fig.add_subplot(n_rows, n_cols, i + 1, xticks=[], yticks=[])
            # Images are laid out channel-first; imshow wants channel-last.
            img = img.transpose(1, 2, 0)
            if img.shape[2] == 1:
                # imshow rejects (H, W, 1); draw single-channel data as 2-D.
                ax.imshow(img[:, :, 0], cmap='gray')
            else:
                ax.imshow(img)
        fig.savefig('{}/generate_{:03d}'.format(savedir, epoch))
    finally:
        # Close this specific figure (not just the "current" one), even if
        # plotting or saving fails, so repeated calls do not leak figures.
        plt.close(fig)
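# (Stray web-page labels from the scraped source, "Comment list" and
# "Table of contents", were turned into this comment so the file parses.)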