infer.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:neural-fonts 作者: periannath 项目源码 文件源码
def main(_):
    """Run font inference or embedding interpolation as configured by ``args``.

    Builds a ``UNet`` with ``is_training=False`` inside a TensorFlow session
    and then, depending on the flags on the module-level ``args`` object:

    * ``args.interpolate`` falsy: calls ``model.infer_compare`` when
      ``args.compare`` is set, otherwise ``model.infer``. A single embedding
      id is passed as a bare int, several ids as a list.
    * ``args.interpolate`` truthy: calls ``model.interpolate`` once for each
      pair of consecutive embedding ids (with the first id appended to the
      end of the chain when ``args.uroboros`` is set), then, if
      ``args.output_gif`` is set, calls ``compile_frames_to_gif`` on
      ``args.save_dir`` and prints the resulting gif path.

    Args:
        _: Ignored positional argument.

    Raises:
        Exception: If interpolation is requested with fewer than two
            embedding ids.
    """
    session_config = tf.ConfigProto()
    # Ask TensorFlow to grow GPU memory usage on demand.
    session_config.gpu_options.allow_growth = True

    # Make sure the output directory exists before the model writes to it.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    with tf.Session(config=session_config) as sess:
        model = UNet(batch_size=args.batch_size)
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=args.inst_norm)

        # "1,2,3" -> [1, 2, 3]
        id_list = [int(token) for token in args.embedding_ids.split(",")]

        if args.interpolate:
            if len(id_list) < 2:
                raise Exception("no need to interpolate yourself unless you are a narcissist")

            # Optionally close the loop by returning to the first embedding.
            chain = (id_list + id_list[:1]) if args.uroboros else id_list[:]

            # Walk every pair of neighbouring ids along the chain.
            for start_id, end_id in zip(chain, chain[1:]):
                model.interpolate(
                    model_dir=args.model_dir,
                    source_obj=args.source_obj,
                    between=[start_id, end_id],
                    save_dir=args.save_dir,
                    steps=args.steps,
                )

            if args.output_gif:
                gif_path = os.path.join(args.save_dir, args.output_gif)
                compile_frames_to_gif(args.save_dir, gif_path)
                print("gif saved at %s" % gif_path)
        else:
            # A lone id is handed over as a scalar rather than a 1-element list.
            target_ids = id_list[0] if len(id_list) == 1 else id_list

            if args.compare:
                model.infer_compare(
                    model_dir=args.model_dir,
                    source_obj=args.source_obj,
                    embedding_ids=target_ids,
                    save_dir=args.save_dir,
                    show_ssim=args.show_ssim,
                )
            else:
                model.infer(
                    model_dir=args.model_dir,
                    source_obj=args.source_obj,
                    embedding_ids=target_ids,
                    save_dir=args.save_dir,
                    progress_file=args.progress_file,
                )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号