main.py source code

python

Project: gan-error-avoidance  Author: aleju
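import imageio
import numpy as np
import torch

# `gen` (the generator), `opt` (parsed options) and `util` (grid helper)
# are defined or imported elsewhere in main.py.


@torch.no_grad()  # inference only, no autograd graph needed
def visualize(code, filename, filename_r, filename_all):
    """Save image grids for a fixed batch of latent codes: one for the base
    output (filename), one per extra r_iter (filename_r), one combining all."""
    gen.eval()
    generated_by_riter = [[] for _ in range(1 + opt.r_iterations)]

    # Ceiling division: the last batch may be smaller than opt.batch_size.
    for i in range((code.size(0) - 1) // opt.batch_size + 1):
        start = i * opt.batch_size
        batch_code = code[start:start + opt.batch_size]  # slicing clips at the end

        # n_execute_lis_layers=0 gives the base output; larger values run more LIS layers.
        for r_iter in range(1 + opt.r_iterations):
            imgs, _ = gen(batch_code, n_execute_lis_layers=r_iter)
            if opt.output_scale:
                # Assumes output_scale means outputs lie in [-1, 1]; map back to [0, 1].
                imgs = (imgs + 1) / 2
            # Clamp so the uint8 cast cannot wrap around; NCHW -> NHWC.
            imgs_np = (imgs.clamp(0, 1).cpu().numpy() * 255).astype(np.uint8).transpose((0, 2, 3, 1))
            generated_by_riter[r_iter].extend(imgs_np)

    # Per code, its outputs for every r_iter placed side by side.
    generated_all = [np.hstack([riter_imgs[i] for riter_imgs in generated_by_riter])
                     for i in range(len(generated_by_riter[0]))]

    imageio.imwrite(filename, util.draw_grid(generated_by_riter[0], cols=opt.vis_col))
    for r_iter in range(1, 1 + opt.r_iterations):
        imageio.imwrite(filename_r.format(r_iter - 1),
                        util.draw_grid(generated_by_riter[r_iter], cols=opt.vis_col))
    imageio.imwrite(filename_all, util.draw_grid(generated_all, cols=opt.vis_col))
    gen.train()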
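The array bookkeeping in `visualize` can be tried without a trained model. This covers ceiling-division batching, converting an NCHW float batch to NHWC `uint8` images, stacking each code's outputs horizontally, and tiling them into a grid. What follows is a minimal NumPy sketch, assuming generator outputs already lie in [0, 1]. Random arrays stand in for `gen(...)`. The names `batch_bounds`, `to_uint8_nhwc` and `draw_grid_sketch` are hypothetical helpers. The last one is only a guess at what the project's `util.draw_grid` does, and the real helper may lay out or pad the grid differently.

```python
import numpy as np


def batch_bounds(n_items, batch_size):
    """Yield (start, end) pairs covering n_items in chunks of batch_size.

    (n_items - 1) // batch_size + 1 is ceiling division, so the last
    chunk may be shorter than batch_size.
    """
    for i in range((n_items - 1) // batch_size + 1):
        start = i * batch_size
        yield start, min(start + batch_size, n_items)


def to_uint8_nhwc(batch):
    """Convert a float NCHW batch in [0, 1] to a list of uint8 HWC images."""
    batch = np.clip(batch, 0.0, 1.0)  # avoid wrap-around in the uint8 cast
    return list((batch * 255).astype(np.uint8).transpose((0, 2, 3, 1)))


def draw_grid_sketch(images, cols):
    """Hypothetical stand-in for util.draw_grid: tile equally sized HWC
    images row by row, leaving unused cells of the last row black."""
    h, w, c = images[0].shape
    rows = (len(images) - 1) // cols + 1
    grid = np.zeros((rows * h, cols * w, c), dtype=images[0].dtype)
    for idx, img in enumerate(images):
        r, col = divmod(idx, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    return grid


# Toy setup: 5 latent codes, r_iter in {0, 1}, batches of 2, 8x8 RGB images.
rng = np.random.default_rng(0)
n_codes, r_iterations, batch_size = 5, 1, 2
by_riter = [[] for _ in range(1 + r_iterations)]
for start, end in batch_bounds(n_codes, batch_size):
    for r_iter in range(1 + r_iterations):
        fake = rng.random((end - start, 3, 8, 8))  # stands in for gen(...)
        by_riter[r_iter].extend(to_uint8_nhwc(fake))

# One wide image per code: its outputs for every r_iter side by side.
combined = [np.hstack([imgs[i] for imgs in by_riter]) for i in range(n_codes)]
grid = draw_grid_sketch(combined, cols=2)
print(list(batch_bounds(n_codes, batch_size)))  # [(0, 2), (2, 4), (4, 5)]
print(combined[0].shape, grid.shape)            # (8, 16, 3) (24, 32, 3)
```

With 5 codes and `cols=2`, the grid has 3 rows and the last cell stays empty. Each combined tile is twice as wide as a single image because there are two r_iter values.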