eval.py 文件源码

python
阅读 42 收藏 0 点赞 0 评论 0

项目:prisma 作者: hijkzzz 项目源码 文件源码
def generate():
    """Stylize FLAGS.CONTENT_IMAGE with the model checkpoint at FLAGS.MODEL_PATH.

    The content image is decoded as PNG when its name ends in ``.png``
    (case-insensitive) and as JPEG otherwise, forced to 3 channels, scaled to
    float32 in [0, 255], has ``vgg.MEAN_PIXEL`` subtracted and is passed
    through ``transform.net(..., training=False)``. Weights are restored from
    FLAGS.MODEL_PATH and the uint8 result is written to FLAGS.OUTPUT_FOLDER as
    ``<image-stem>-<model-basename><image-ext>``.

    Returns:
        None. Logs an error and returns early when FLAGS.CONTENT_IMAGE is unset.
    """
    if not FLAGS.CONTENT_IMAGE:
        tf.logging.error("generating a fast neural style image requires the content image path (CONTENT_IMAGE) to be set")
        return

    # makedirs (not mkdir) so an output folder whose parents don't exist yet still works.
    if not os.path.exists(FLAGS.OUTPUT_FOLDER):
        os.makedirs(FLAGS.OUTPUT_FOLDER)

    path = FLAGS.CONTENT_IMAGE
    # Compare including the dot so e.g. "foo.notpng" is not mistaken for a PNG.
    is_png = path.lower().endswith('.png')
    # abspath also drops a trailing separator, so basename() below is never empty.
    model_path = os.path.abspath(FLAGS.MODEL_PATH)

    # One private graph and one session: the session is closed when the block
    # exits and no ops are added to the process-wide default graph.
    with tf.Graph().as_default(), tf.Session() as sess:
        # Load and decode the content image; channels=3 forces RGB input.
        img_bytes = tf.read_file(path)
        if is_png:
            raw_image = tf.image.decode_png(img_bytes, channels=3)
        else:
            raw_image = tf.image.decode_jpeg(img_bytes, channels=3)
        # convert_image_dtype maps uint8 to float32 in [0, 1]; * 255 restores the [0, 255] range.
        content_image = tf.image.convert_image_dtype(raw_image, tf.float32) * 255.0
        content_image = tf.expand_dims(content_image, 0)  # batch of one image

        generated_images = transform.net(content_image - vgg.MEAN_PIXEL, training=False)
        output_format = tf.saturate_cast(generated_images, tf.uint8)
        # (height, width, channels) of the decoded input, fetched in the same run as the output.
        image_shape = tf.shape(raw_image)

        # Restore the trained weights (the initializers run first, as before).
        saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        tf.logging.info('Usage model {}'.format(model_path))
        saver.restore(sess, model_path)

        stem, extension = os.path.splitext(os.path.basename(path))
        filename = stem + '-' + os.path.basename(model_path) + extension
        tf.logging.info("image {}".format(filename))

        # A single run yields both the stylized batch and the input size,
        # so the image file is read and decoded only once.
        images_t, shape = sess.run([output_format, image_shape])
        height, width = shape[0], shape[1]
        tf.logging.info('Image size: %dx%d' % (width, height))

        assert len(images_t) == 1
        # NOTE(review): `misc` is presumably scipy.misc, whose imsave was removed
        # in newer SciPy releases — confirm the pinned version.
        misc.imsave(os.path.join(FLAGS.OUTPUT_FOLDER, filename), images_t[0])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号