def get_style_features(FLAGS):
    """Evaluate the Gram matrices of the style image in the loss network.

    For the "style_image", the preprocessing step is:
    1. Resize the shorter side to FLAGS.image_size
    2. Apply central crop
    (Both steps are delegated to ``utils.get_image`` with the eval-mode
    preprocessing function from ``preprocessing_factory``; confirm the exact
    behavior against those modules.)

    As a side effect, the preprocessed style image is passed through
    ``unprocess_image``, cast to uint8, and written as a JPEG to
    ``generated/target_style_<FLAGS.naming>.jpg``. The ``generated``
    directory is created if it does not exist.

    Args:
        FLAGS: configuration object. This function reads ``loss_model``,
            ``style_image``, ``image_size``, ``style_layers`` and ``naming``,
            and passes FLAGS on to ``utils._get_init_fn``.

    Returns:
        A list with one evaluated Gram matrix per layer name in
        ``FLAGS.style_layers``, in the same order.
    """
    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True
    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)
        # Prepend a batch dimension of size 1 before feeding the network.
        images = tf.expand_dims(
            utils.get_image(FLAGS.style_image, FLAGS.image_size,
                            FLAGS.image_size, image_preprocessing_fn),
            0)
        _, endpoints_dict = network_fn(images)
        features = [gram(endpoints_dict[layer]) for layer in FLAGS.style_layers]

        # Build the JPEG-encoding op up front so the saved image and the
        # style features are produced by a single session run.
        target_image = unprocess_image(images[0, :])
        encoded_jpeg = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))

        init_func = utils._get_init_fn(FLAGS)
        init_func(sess)

        encoded, style_features = sess.run([encoded_jpeg, features])

        output_dir = 'generated'
        # Create the directory without an exists-then-create race; this form
        # also works on Python versions whose makedirs lacks ``exist_ok``.
        try:
            os.makedirs(output_dir)
        except OSError:
            if not os.path.isdir(output_dir):
                raise
        save_file = os.path.join(output_dir,
                                 'target_style_' + FLAGS.naming + '.jpg')
        # Open the file only once the encoded bytes exist, so a failed
        # session run cannot leave an empty JPEG behind.
        with open(save_file, 'wb') as f:
            f.write(encoded)
        tf.logging.info('Target style pattern is saved to: %s.', save_file)
        return style_features
评论列表
文章目录