def main(_):
    """Run glyph inference or embedding interpolation with a trained UNet.

    All settings are read from the module-level ``args`` object (presumably an
    argparse namespace built elsewhere in this file -- confirm). The single
    positional parameter is ignored; presumably it is the argv list that
    ``tf.app.run`` passes to ``main`` -- confirm against the caller.

    Two modes, selected by ``args.interpolate``:

    * inference (default): run ``model.infer`` on ``args.source_obj`` with the
      embedding id(s) from ``args.embedding_ids``. If ``args.compare`` is set,
      run ``model.infer_compare`` instead. Results go to ``args.save_dir``.
    * interpolation: for each consecutive pair of embedding ids, call
      ``model.interpolate`` with ``args.steps`` steps. If ``args.uroboros`` is
      set, the chain also closes back to the first id. If ``args.output_gif``
      is set, the frames in ``args.save_dir`` are then compiled into a gif.

    Raises:
        ValueError: if ``args.embedding_ids`` yields no ids, contains a
            non-integer token, or has fewer than two ids in interpolation mode.
    """
    # Validate the arguments before building the graph so bad input fails fast.
    # Blank tokens (e.g. from a trailing comma in "1,2,") are skipped.
    embedding_ids = [int(token) for token in args.embedding_ids.split(",") if token.strip()]
    if not embedding_ids:
        raise ValueError("args.embedding_ids contains no embedding ids")
    if args.interpolate and len(embedding_ids) < 2:
        raise ValueError("no need to interpolate yourself unless you are a narcissist")

    # Create the output directory. If it already exists (or appears
    # concurrently), carry on; re-raise any other failure. This avoids the
    # exists()/makedirs() race without needing makedirs(exist_ok=True),
    # which older interpreters (Python 2) lack.
    try:
        os.makedirs(args.save_dir)
    except OSError:
        if not os.path.isdir(args.save_dir):
            raise

    config = tf.ConfigProto()
    # Allocate GPU memory on demand rather than reserving it all up front.
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        model = UNet(batch_size=args.batch_size)
        model.register_session(sess)
        model.build_model(is_training=False, inst_norm=args.inst_norm)

        if not args.interpolate:
            # A single id is passed as a bare int, not a one-element list.
            # Presumably infer()/infer_compare() handle the two forms
            # differently -- confirm before changing.
            if len(embedding_ids) == 1:
                embedding_ids = embedding_ids[0]
            if args.compare:
                model.infer_compare(model_dir=args.model_dir, source_obj=args.source_obj,
                                    embedding_ids=embedding_ids, save_dir=args.save_dir,
                                    show_ssim=args.show_ssim)
            else:
                model.infer(model_dir=args.model_dir, source_obj=args.source_obj,
                            embedding_ids=embedding_ids, save_dir=args.save_dir,
                            progress_file=args.progress_file)
        else:
            chains = list(embedding_ids)
            if args.uroboros:
                # Close the loop: the last id interpolates back to the first.
                chains.append(chains[0])
            # Consecutive (start, end) pairs: [a, b, c] -> (a, b), (b, c).
            for start, end in zip(chains, chains[1:]):
                model.interpolate(model_dir=args.model_dir, source_obj=args.source_obj,
                                  between=[start, end], save_dir=args.save_dir,
                                  steps=args.steps)
            if args.output_gif:
                gif_path = os.path.join(args.save_dir, args.output_gif)
                compile_frames_to_gif(args.save_dir, gif_path)
                print("gif saved at %s" % gif_path)
评论列表
文章目录