def transfer():
    """Stylize one content image with a trained checkpoint and save it as JPEG.

    All configuration is read from FLAGS:
      ckpt_path:    a checkpoint file, or a directory whose most recent
                    checkpoint (as reported by tf.train.latest_checkpoint)
                    is used.
      content_path: path of the content image; must exist and must not be
                    a directory.
      padding:      number of pixels cropped from each spatial border of
                    the network output.
      styled_path:  destination file for the JPEG-encoded result.

    Raises:
        ValueError: if ckpt_path does not resolve to a checkpoint, or if
            content_path is missing or is a directory.
    """
    # Resolve the checkpoint: a directory means "use its newest checkpoint".
    if tf.gfile.IsDirectory(FLAGS.ckpt_path):
        # latest_checkpoint returns None when the directory holds none.
        ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
    elif tf.gfile.Exists(FLAGS.ckpt_path):
        ckpt_source_path = FLAGS.ckpt_path
    else:
        ckpt_source_path = None

    # Explicit raises rather than `assert`: assertions are stripped under
    # `python -O`, which would leave ckpt_source_path unbound.
    if not ckpt_source_path:
        raise ValueError('bad checkpoint: {}'.format(FLAGS.ckpt_path))
    if (not tf.gfile.Exists(FLAGS.content_path)
            or tf.gfile.IsDirectory(FLAGS.content_path)):
        raise ValueError('bad content_path: {}'.format(FLAGS.content_path))

    image_contents = build_contents_reader()
    network = build_style_transfer_network(image_contents, training=False)

    # Crop `padding` pixels off both spatial borders. Assumes NHWC layout,
    # so axis 1 is height and axis 2 is width -- TODO confirm against
    # build_style_transfer_network.
    shape = tf.shape(network['image_styled'])
    new_h = shape[1] - 2 * FLAGS.padding
    new_w = shape[2] - 2 * FLAGS.padding
    image_styled = tf.slice(
        network['image_styled'],
        [0, FLAGS.padding, FLAGS.padding, 0],
        [-1, new_h, new_w, -1])

    # Drop the batch axis; tf.squeeze with an explicit axis requires that
    # axis to have size 1, i.e. a single image per run.
    image_styled = tf.squeeze(image_styled, [0])
    # Affine map [-1, 1] -> [0, 255]; presumably the network output is
    # tanh-scaled -- verify.
    image_styled = image_styled * 127.5 + 127.5
    # Reverse the last axis of the rank-3 image (the channel axis if the
    # layout is HWC), e.g. BGR <-> RGB -- confirm the network's channel order.
    image_styled = tf.reverse(image_styled, [2])
    # saturate_cast clamps to the uint8 range instead of wrapping.
    image_styled = tf.saturate_cast(image_styled, tf.uint8)
    image_styled = tf.image.encode_jpeg(image_styled)
    image_styled_writer = tf.write_file(FLAGS.styled_path, image_styled)

    with tf.Session() as session:
        tf.train.Saver().restore(session, ckpt_source_path)
        # The contents reader presumably relies on input queues; start their
        # runner threads so the run() below can dequeue the image.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=session, coord=coord)
        try:
            session.run(image_styled_writer)
        finally:
            # Stop and join the runner threads even if run() fails, so they
            # are not left running against a closed session.
            coord.request_stop()
            coord.join(threads)
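# Residue from the web page this code was copied from (not code):
# "评论列表" ("Comment list") / "文章目录" ("Table of contents").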