def translate():
    """Translate every image listed by ``prepare_paths()`` with a restored CycleGAN.

    ``prepare_paths()`` yields pairs whose first element is the path of an
    image to read and whose second element is the path to write the result
    to. Images are processed ``FLAGS.batch_size`` at a time. Each one is run
    through the model built by ``build_cycle_gan`` (configured by
    ``FLAGS.mode``). The input and the generated image are saved side by
    side, input on the left and output on the right, to the second path.

    The input placeholder is fixed at ``[None, 256, 256, 3]`` uint8, so every
    input image must be 256x256. Images are converted to RGB on read so
    grayscale or RGBA files still provide 3 channels.

    Raises:
        ValueError: if ``FLAGS.ckpt_dir_path`` contains no checkpoint.
    """
    image_path_pairs = prepare_paths()

    # Real images enter as uint8 in [0, 255] and are rescaled to [-1, 1].
    reals = tf.placeholder(shape=[None, 256, 256, 3], dtype=tf.uint8)
    flow = tf.cast(reals, dtype=tf.float32) / 127.5 - 1.0

    # NOTE(review): the same tensor is passed as both inputs; presumably
    # FLAGS.mode selects the translation direction used at inference —
    # confirm against build_cycle_gan.
    model = build_cycle_gan(flow, flow, FLAGS.mode)

    # Map the generator output (presumably in [-1, 1]) back to [0, 255];
    # saturate_cast clamps out-of-range values instead of wrapping them.
    fakes = tf.saturate_cast(model['fake'] * 127.5 + 127.5, tf.uint8)

    # Path to the most recent checkpoint; latest_checkpoint returns None
    # when the directory holds none, so fail early with a clear message.
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)
    if ckpt_source_path is None:
        raise ValueError(
            'no checkpoint found in {}'.format(FLAGS.ckpt_dir_path))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())
        tf.train.Saver().restore(session, ckpt_source_path)

        for start in range(0, len(image_path_pairs), FLAGS.batch_size):
            path_pairs = image_path_pairs[start:start + FLAGS.batch_size]

            # mode='RGB' forces 3 channels so grayscale / RGBA files fit
            # the placeholder. Spatial size must still be 256x256.
            real_images = [
                scipy.misc.imread(pair[0], mode='RGB') for pair in path_pairs]

            fake_images = session.run(fakes, feed_dict={reals: real_images})

            for pair, real, fake in zip(path_pairs, real_images, fake_images):
                # Images are (H, W, C), so axis=1 (width) puts the input and
                # its translation side by side.
                scipy.misc.imsave(
                    pair[1], np.concatenate([real, fake], axis=1))
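# 评论列表 ("Comment list" — scraped web-page residue, not code)
# 文章目录 ("Table of contents" — scraped web-page residue, not code)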