losses.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:fast-neural-style 作者: coder-james 项目源码 文件源码
def get_style_features(FLAGS):
    """Compute the style-target features for ``FLAGS.style_image``.

    The style image is loaded via ``utils.get_image`` at
    ``FLAGS.image_size`` x ``FLAGS.image_size`` using the loss model's
    non-training preprocessing function. Per the original notes, the
    preprocessing for the "style_image" is:
    1. Resize the shorter side to FLAGS.image_size
    2. Apply central crop

    The preprocessed image is fed through the loss network and ``gram`` is
    applied to the endpoint of every layer named in ``FLAGS.style_layers``.
    As a side effect, the preprocessed style image is also written to
    ``generated/target_style_<FLAGS.naming>.jpg`` so it can be inspected.

    Args:
        FLAGS: configuration object. Reads ``loss_model``, ``style_image``,
            ``image_size``, ``style_layers`` and ``naming``, plus whatever
            ``utils._get_init_fn`` reads from it.

    Returns:
        A list with one evaluated value (as returned by ``Session.run``) per
        entry of ``FLAGS.style_layers``, in the same order.
    """
    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True
    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)

        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        image = utils.get_image(FLAGS.style_image, FLAGS.image_size,
                                FLAGS.image_size, image_preprocessing_fn)
        # Add a leading batch dimension of size 1 before feeding the network.
        images = tf.expand_dims(image, 0)
        _, endpoints_dict = network_fn(images)

        features = [gram(endpoints_dict[layer]) for layer in FLAGS.style_layers]

        # NOTE(review): presumably restores the pretrained loss-model weights
        # into this session — confirm against utils._get_init_fn.
        init_func = utils._get_init_fn(FLAGS)
        init_func(sess)

        # Drop the batch dimension; unprocess_image presumably reverses the
        # preprocessing so the JPEG looks like a normal image — TODO confirm.
        target_image = unprocess_image(images[0, :])
        encoded_jpeg = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))

        # Evaluate the encoding BEFORE opening the output file, so a failure
        # here cannot leave an empty or truncated JPEG on disk.
        jpeg_bytes = sess.run(encoded_jpeg)

        if not os.path.exists('generated'):
            os.makedirs('generated')
        save_file = 'generated/target_style_' + FLAGS.naming + '.jpg'
        with open(save_file, 'wb') as f:
            f.write(jpeg_bytes)
        tf.logging.info('Target style pattern is saved to: %s.', save_file)

        return sess.run(features)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号