def ae_reconstruction(enc, dec, eval_folder, gpu, data_iter, batch_size=32, img_chan=3, img_size=64):
    """Build a Chainer trainer extension that saves autoencoder reconstructions.

    Each time the returned extension runs, it takes one batch from
    ``data_iter`` and passes it through ``enc`` and then ``dec``. It then
    writes two image grids into ``eval_folder``:

    * ``iter_<N>.rec.jpg`` holds the decoder output.
    * ``iter_<N>.real.jpg`` holds the input batch.

    ``<N>`` is ``trainer.updater.iteration``.

    Args:
        enc: Encoder link. ``enc.xp`` selects the array module, and the
            encoder is called as ``enc(x, test=True)``.
        dec: Decoder link, called as ``dec(z, test=True)``.
        eval_folder: Directory the image grids are written into.
        gpu: Not used by this function; kept for interface compatibility.
        data_iter: Iterator yielding batches, where each batch is a
            sequence of images.
        batch_size: Number of image slots in each saved grid.
        img_chan: Channels per image.
        img_size: Height and width of each square image buffer.

    Returns:
        The extension function, decorated with
        ``chainer.training.make_extension``.
    """
    # The grid has 8 rows. This assumes batch_size is a multiple of 8 (and
    # at least 8); TODO confirm how save_images_grid handles other sizes.
    grid_h = 8
    grid_w = batch_size // grid_h

    @chainer.training.make_extension()
    def sample_reconstruction(trainer):
        xp = enc.xp
        # Builtin next() works on Python 2 and 3, unlike the .next() method.
        batch = next(data_iter)
        # Allocate float32 directly instead of float64 followed by an
        # .astype copy.
        d_real = xp.zeros((batch_size, img_chan, img_size, img_size), dtype="f")
        # Copy at most batch_size samples. A short final batch leaves the
        # unused slots zero instead of raising IndexError.
        for i, sample in enumerate(batch[:batch_size]):
            d_real[i, :] = xp.asarray(sample)
        # volatile=True is the Chainer v1 way to skip building the graph.
        # test=True is presumably the enc/dec flag for inference mode
        # (TODO confirm against their definitions).
        x = Variable(d_real, volatile=True)
        imgs = dec(enc(x, test=True), test=True)
        prefix = eval_folder + "/iter_" + str(trainer.updater.iteration)
        save_images_grid(imgs, path=prefix + ".rec.jpg",
                         grid_w=grid_w, grid_h=grid_h)
        save_images_grid(d_real, path=prefix + ".real.jpg",
                         grid_w=grid_w, grid_h=grid_h)

    return sample_reconstruction
# Web-page residue from the source article: "评论列表" (comment list),
# "文章目录" (table of contents).