style.py source code

python

Project: deep-style-transfer    Author: albertlai
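import argparse
import logging
import os
import sys
from datetime import datetime

import tensorflow as tf  # TensorFlow 1.x API (tf.GraphDef, tf.Session, tf.ConfigProto)

import utils  # project-local module providing load_image() and write_image()


def main():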
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_file', type=str, default='data/vg-30.pb',
                        help='Pretrained model file to run')
    parser.add_argument('--input', type=str, default='data/sf.jpg',
                        help='Input image to process')
    parser.add_argument('--output', type=str, default='output.png',
                        help='Output image file')

    args = parser.parse_args()
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO,
                        datefmt='%I:%M:%S')

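    # Read the serialized GraphDef and import it into the default graph;
    # tf.import_graph_def prefixes every imported node with "import/".
    with open(args.model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def)
    graph = tf.get_default_graph()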

    with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=4)) as session:
        logging.info("Initializing graph")
        # Initialize any variables (a no-op if the imported graph has none).
        session.run(tf.global_variables_initializer())

        # Derive the scope name from the file name, e.g. "data/vg-30.pb" -> "vg-30".
        model_name = os.path.splitext(os.path.basename(args.model_file))[0]
        image = graph.get_tensor_by_name("import/%s/image_in:0" % model_name)
        out = graph.get_tensor_by_name("import/%s/output:0" % model_name)

        # shape[1] and shape[2] are the height and width the input tensor expects;
        # wrapping the image in a list adds a leading batch dimension of one.
        shape = image.get_shape().as_list()
        target = [utils.load_image(args.input, image_h=shape[1], image_w=shape[2])]
        logging.info("Processing image")
        start_time = datetime.now()
        processed = session.run(out, feed_dict={image: target})
        logging.info("Processing took %f s", (datetime.now() - start_time).total_seconds())
        utils.write_image(args.output, processed)
        logging.info("Done")


if __name__ == '__main__':
    main()
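The tensor lookups in `main()` rest on a naming convention: `tf.import_graph_def` prefixes imported nodes with `import/` by default, and the code assumes the model's ops sit under a scope named after the `.pb` file. A minimal sketch that isolates that convention as a pure function (the name `graph_tensor_names` is hypothetical, not part of the project) lets it be checked without loading TensorFlow:

```python
import os


def graph_tensor_names(model_file, import_scope="import"):
    """Return (input_name, output_name) for a style-transfer graph file.

    Hypothetical helper. Assumes, as main() does, that the graph was imported
    under tf.import_graph_def's default "import" prefix and that the model's
    ops live under a scope named after the .pb file's stem.
    """
    # "data/vg-30.pb" -> "vg-30"; splitext drops only the final extension.
    model_name = os.path.splitext(os.path.basename(model_file))[0]
    prefix = "%s/%s" % (import_scope, model_name)
    # ":0" selects the first output tensor of each named op.
    return prefix + "/image_in:0", prefix + "/output:0"


if __name__ == "__main__":
    print(graph_tensor_names("data/vg-30.pb"))
```

Taking the stem with `os.path.splitext` rather than slicing off a fixed number of characters keeps the scope name correct for any extension length, and file names containing extra dots keep everything before the final extension.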
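`main()` times inference by subtracting two `datetime.now()` readings, which follow the wall clock and can shift if the system time is adjusted during the run. A sketch of an alternative, using a hypothetical helper `log_elapsed` (not part of the project), measures with `time.perf_counter()`, the standard-library clock intended for timing short durations:

```python
import logging
import time
from contextlib import contextmanager


@contextmanager
def log_elapsed(label):
    """Time the enclosed block and log the result (hypothetical helper)."""
    timing = {}
    start = time.perf_counter()  # high-resolution clock meant for durations
    try:
        yield timing
    finally:
        # Record and log the elapsed time even if the block raised.
        timing["seconds"] = time.perf_counter() - start
        logging.info("%s took %.3f s", label, timing["seconds"])


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    with log_elapsed("Sleeping") as t:
        time.sleep(0.05)
    print(t["seconds"])
```

In `main()` this would wrap the `session.run(out, feed_dict={image: target})` call, e.g. `with log_elapsed("Processing"):`, so the timing and its log line stay together in one place.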