style_transfer.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ml_styles 作者: imironhead 项目源码 文件源码
def load_image(path):
    """Load a JPEG file and preprocess it into a padded 4-D image batch.

    The image is decoded as 3-channel RGB, resized to 256x256, rescaled from
    [0, 255] to [-1, 1], converted from R/G/B to B/G/R channel order and
    symmetrically padded by ``FLAGS.padding`` pixels on each side of the
    height and width axes.

    The ops are added to the current default graph and evaluated in a
    temporary session, so every call grows that graph.

    Args:
        path: path of the JPEG file to load (GIF is not supported).

    Returns:
        A tuple ``(image, shape)`` of numpy arrays:
        image: float32 array of shape
            ``[1, 256 + 2 * FLAGS.padding, 256 + 2 * FLAGS.padding, 3]``
            in B/G/R order with values in [-1, 1].
        shape: ``[height, width]`` of the image as decoded, i.e. before it
            was resized.
    """
    file_names = tf.train.string_input_producer([path])

    _, encoded = tf.WholeFileReader().read(file_names)

    # Decode byte data, no gif please.
    # NOTE: tf.image.decode_image can decode both jpeg and png. However, the
    #       shape (height and width) is unknown.
    image_op = tf.image.decode_jpeg(encoded, channels=3)
    image_op = tf.cast(image_op, tf.float32)

    # Remember the decoded height/width before the image gets resized.
    shape_op = tf.shape(image_op)[:2]

    image_op = tf.image.resize_images(image_op, [256, 256])
    image_op = tf.reshape(image_op, [1, 256, 256, 3])

    # For VggNet one would subtract the mean color of its training data
    # instead (disabled, kept for reference):
    # image = tf.subtract(image, VggNet.mean_color_rgb())

    # Already float32 at this point: map [0, 255] to [-1, 1].
    image_op = image_op / 127.5 - 1.0

    # R/G/B to B/G/R
    image_op = tf.reverse(image_op, [3])

    # Pad only the height and width axes; batch and channel stay untouched.
    # Assumes FLAGS.padding is a non-negative int -- confirm against the
    # flag definition.
    padding = [FLAGS.padding, FLAGS.padding]

    image_op = tf.pad(
        tensor=image_op,
        paddings=[[0, 0], padding, padding, [0, 0]],
        mode='symmetric')

    with tf.Session() as session:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            image, shape = session.run([image_op, shape_op])
        finally:
            # Always stop and join the queue-runner threads, even when
            # reading or decoding fails, so they do not outlive the session.
            coord.request_stop()
            coord.join(threads)

    return image, shape
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号