def visualize(code, filename, filename_r, filename_all):
    """Render generator samples at every refinement depth and save them as image grids.

    ``code`` is processed in mini-batches of ``opt.batch_size`` along dim 0.
    Each batch is passed through ``gen`` once for every ``r_iter`` in
    ``0 .. opt.r_iterations`` (forwarded as ``n_execute_lis_layers``).
    Every output is converted to a ``uint8`` array, with axis 1 moved to the
    last position.

    Files written with ``misc.imsave`` (each grid built by ``util.draw_grid``
    with ``cols=opt.vis_col``):

    * ``filename`` -- the ``r_iter == 0`` samples;
    * ``filename_r.format(k)`` for ``k = 0 .. opt.r_iterations - 1`` -- the
      ``r_iter == k + 1`` samples;
    * ``filename_all`` -- one cell per sample, holding that sample's images for
      all ``r_iter`` values concatenated horizontally in increasing order.

    ``gen`` is put into eval mode for the duration of the call and switched
    back to train mode afterwards, even if an exception is raised.

    Args:
        code: tensor of generator inputs; only ``code.size(0)`` and slicing on
            dim 0 are used here (presumably a latent batch -- confirm with caller).
        filename: output path for the ``r_iter == 0`` grid.
        filename_r: ``str.format`` template with one positional field, filled
            with ``r_iter - 1``.
        filename_all: output path for the side-by-side grid.
    """
    gen.eval()
    try:
        n_samples = code.size(0)
        n_depths = 1 + opt.r_iterations
        generated_by_riter = [[] for _ in range(n_depths)]

        for start in range(0, n_samples, opt.batch_size):
            # Explicit end bound keeps the last (possibly short) batch exact.
            end = min(start + opt.batch_size, n_samples)
            # NOTE(review): no volatile=True / torch.no_grad() here, so autograd
            # presumably records these forward passes; consider disabling it if
            # the PyTorch version in use allows.
            batch_code = Variable(code[start:end])
            for r_iter in range(n_depths):
                imgs, _ = gen(batch_code, n_execute_lis_layers=r_iter)
                if opt.output_scale:
                    # NOTE(review): this maps [0, 1] -> [-1, 1] (and [-1, 1] ->
                    # [-3, 1]), so it always yields negative pixel values. The
                    # intent may have been the inverse, (imgs + 1) / 2, or no
                    # transform at all. Kept as-is pending confirmation of gen's
                    # output range.
                    imgs = imgs * 2 - 1
                imgs_np = imgs.data.cpu().numpy()
                # Clip before the uint8 cast: casting out-of-range floats to an
                # integer dtype is undefined in NumPy and typically wraps around.
                imgs_np = (np.clip(imgs_np, 0, 1) * 255).astype(np.uint8)
                # Move axis 1 to the end (presumably channels: NCHW -> NHWC).
                generated_by_riter[r_iter].extend(imgs_np.transpose((0, 2, 3, 1)))

        # Every depth list holds the same samples in the same order, so zip
        # pairs up one sample's images across all depths.
        generated_all = [np.hstack(per_depth) for per_depth in zip(*generated_by_riter)]

        misc.imsave(filename, util.draw_grid(generated_by_riter[0], cols=opt.vis_col))
        for r_iter in range(1, n_depths):
            misc.imsave(filename_r.format(r_iter - 1),
                        util.draw_grid(generated_by_riter[r_iter], cols=opt.vis_col))
        misc.imsave(filename_all, util.draw_grid(generated_all, cols=opt.vis_col))
    finally:
        # Restore training mode even when generation or saving fails.
        gen.train()
评论列表
文章目录