def inference():
    """Run the serialized model graph on a single JPEG and save its output.

    Steps, all driven by command-line ``FLAGS``:

    1. Read the bytes of ``FLAGS.input`` and decode them as a 3-channel JPEG.
    2. Resize to ``FLAGS.image_size`` x ``FLAGS.image_size``, pass the result
       through ``utils.convert2float`` and pin its static shape.
    3. Parse the ``GraphDef`` stored at ``FLAGS.model`` and import it under
       the ``output`` name scope, wiring the preprocessed tensor to the
       model's ``input_image`` input.
    4. Evaluate ``output_image:0`` and write the value, unchanged, to
       ``FLAGS.output`` in binary mode.

    NOTE(review): the fetched value is written straight to disk, so the
    exported graph presumably emits already-encoded image bytes -- confirm
    against the code that exports the model.
    """
    inference_graph = tf.Graph()

    with inference_graph.as_default():
        # Raw bytes of the picture to feed to the model.
        with tf.gfile.FastGFile(FLAGS.input, 'rb') as image_file:
            jpeg_bytes = image_file.read()

        # Preprocessing ops: decode -> resize -> convert2float.
        side = FLAGS.image_size
        decoded = tf.image.decode_jpeg(jpeg_bytes, channels=3)
        resized = tf.image.resize_images(decoded, size=(side, side))
        model_input = utils.convert2float(resized)
        model_input.set_shape([side, side, 3])

        # Deserialize the stored graph definition.
        model_def = tf.GraphDef()
        with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
            model_def.ParseFromString(model_file.read())

        # Must run while inference_graph is the default graph so the imported
        # nodes land in it next to the preprocessing ops.
        fetched = tf.import_graph_def(
            model_def,
            input_map={'input_image': model_input},
            return_elements=['output_image:0'],
            name='output',
        )
        [output_image] = fetched

    with tf.Session(graph=inference_graph) as sess:
        generated = sess.run(output_image)

    with open(FLAGS.output, 'wb') as out_file:
        out_file.write(generated)
评论列表
文章目录