def main(_):
    """Encode an image into bit-packed codes using a frozen graph model.

    Reads the image named by ``FLAGS.input_image``, decodes it, runs it
    through the graph loaded from ``FLAGS.model``, keeps the first
    ``FLAGS.iteration + 1`` output tensors, and writes them (bit-packed and
    saved with ``np.savez_compressed``) to ``FLAGS.output_codes``.

    A usage or validation message is printed and the function returns early
    when a required flag is missing or ``FLAGS.iteration`` is outside the
    inclusive range 0..15.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by the application runner -- confirm against the caller).

    Raises:
        AssertionError: If the input file extension is not ``.png``,
            ``.jpg`` or ``.jpeg`` (compared case-insensitively).
    """
    if (FLAGS.input_image is None or FLAGS.output_codes is None or
            FLAGS.model is None):
        print('\nUsage: python encoder.py --input_image=/your/image/here.png '
              '--output_codes=output_codes.pkl --iteration=15 '
              '--model=residual_gru.pb\n\n')
        return

    if not 0 <= FLAGS.iteration <= 15:
        print('\n--iteration must be between 0 and 15 inclusive.\n')
        return

    # Image files are binary: read them in 'rb' mode so the raw bytes are
    # returned unchanged instead of being decoded as text.
    with tf.gfile.FastGFile(FLAGS.input_image, 'rb') as image_file:
        image_bytes = image_file.read()

    with tf.Graph().as_default() as graph:
        # Load the inference model for encoding.
        with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(model_file.read())
        tf.import_graph_def(graph_def, name='')

        input_tensor = graph.get_tensor_by_name('Placeholder:0')
        outputs = [graph.get_tensor_by_name(name)
                   for name in get_output_tensor_names()]

        # Build the decoding ops in the same graph; the decoder is picked
        # from the file extension.
        image_placeholder = tf.placeholder(tf.string)
        ext = os.path.splitext(FLAGS.input_image)[1]
        normalized_ext = ext.lower()
        if normalized_ext == '.png':
            decoded_image = tf.image.decode_png(image_placeholder, channels=3)
        elif normalized_ext in ('.jpeg', '.jpg'):
            decoded_image = tf.image.decode_jpeg(image_placeholder,
                                                 channels=3)
        else:
            # Raised explicitly (rather than via `assert False`) so the check
            # is not stripped when Python runs with -O.
            raise AssertionError('Unsupported file format {}'.format(ext))
        # Add a leading axis of size 1 before feeding the model input.
        decoded_image = tf.expand_dims(decoded_image, 0)

    with tf.Session(graph=graph) as sess:
        img_array = sess.run(decoded_image,
                             feed_dict={image_placeholder: image_bytes})
        results = sess.run(outputs, feed_dict={input_tensor: img_array})

    # Keep only the outputs of the requested number of iterations.
    results = results[0:FLAGS.iteration + 1]
    int_codes = np.asarray([x.astype(np.int8) for x in results])

    # Convert int codes to binary: (c + 1) // 2 maps -1 -> 0 and +1 -> 1.
    # NOTE(review): assumes the model emits codes in {-1, +1} -- confirm.
    int_codes = (int_codes + 1) // 2
    export = np.packbits(int_codes.reshape(-1))

    # The shape is stored next to the packed bits so a reader can undo the
    # flattening done before packing.
    output = io.BytesIO()
    np.savez_compressed(output, shape=int_codes.shape, codes=export)
    with tf.gfile.FastGFile(FLAGS.output_codes, 'wb') as code_file:
        code_file.write(output.getvalue())
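# NOTE: the two lines originally here ("comment list" / "article table of
# contents") appear to be web-page navigation residue, not program code.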