encoder.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:TAC-GAN 作者: dashayushman 项目源码 文件源码
def main(_):
  """Encode an image with a serialized graph and save the packed codes.

  Reads the image at ``FLAGS.input_image``, imports the GraphDef stored in
  ``FLAGS.model``, runs the decoded image through it, keeps the first
  ``FLAGS.iteration + 1`` outputs and writes them to ``FLAGS.output_codes``
  as a compressed ``.npz`` archive holding two arrays: ``shape`` (shape of
  the unpacked code array) and ``codes`` (the bit-packed codes).

  If a required flag is missing or ``FLAGS.iteration`` is outside 0..15, a
  usage/error message is printed and the function returns without doing
  any work.

  Args:
    _: Unused positional argument (presumably argv supplied by the app
      runner; the call site is not visible here).

  Raises:
    ValueError: If the input image extension is not .png, .jpg or .jpeg
      (compared case-insensitively).
  """
  if (FLAGS.input_image is None or FLAGS.output_codes is None or
      FLAGS.model is None):
    print('\nUsage: python encoder.py --input_image=/your/image/here.png '
          '--output_codes=output_codes.pkl --iteration=15 '
          '--model=residual_gru.pb\n\n')
    return

  if FLAGS.iteration < 0 or FLAGS.iteration > 15:
    print('\n--iteration must be between 0 and 15 inclusive.\n')
    return

  # Validate the extension before touching any file or building the graph.
  # An explicit raise (rather than `assert`) still fires under `python -O`.
  _, ext = os.path.splitext(FLAGS.input_image)
  ext_lower = ext.lower()
  if ext_lower not in ('.png', '.jpg', '.jpeg'):
    raise ValueError('Unsupported file format {}'.format(ext))

  # Image data is binary: read it in 'rb' mode so the bytes are returned
  # untouched instead of being decoded as text.
  with tf.gfile.FastGFile(FLAGS.input_image, 'rb') as image_file:
    image_bytes = image_file.read()

  with tf.Graph().as_default() as graph:
    # Load the inference model for encoding.
    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    # name='' imports the nodes without a scope prefix, so the tensor names
    # looked up below match the names stored in the GraphDef.
    tf.import_graph_def(graph_def, name='')

    input_tensor = graph.get_tensor_by_name('Placeholder:0')
    outputs = [graph.get_tensor_by_name(name) for name in
               get_output_tensor_names()]

    # Small decoding sub-graph: encoded image string -> 3-channel tensor.
    image_placeholder = tf.placeholder(tf.string)
    if ext_lower == '.png':
      decoded_image = tf.image.decode_png(image_placeholder, channels=3)
    else:
      decoded_image = tf.image.decode_jpeg(image_placeholder, channels=3)
    # Add a leading dimension of size 1 before feeding the model input.
    decoded_image = tf.expand_dims(decoded_image, 0)

  with tf.Session(graph=graph) as sess:
    img_array = sess.run(decoded_image,
                         feed_dict={image_placeholder: image_bytes})
    results = sess.run(outputs, feed_dict={input_tensor: img_array})

  # Keep only the outputs for iterations 0..FLAGS.iteration.
  results = results[0:FLAGS.iteration + 1]
  int_codes = np.asarray([x.astype(np.int8) for x in results])

  # Convert int codes to binary: (c + 1) // 2 maps -1 -> 0 and +1 -> 1.
  # NOTE(review): assumes the model emits codes in {-1, +1} -- confirm.
  int_codes = (int_codes + 1)//2
  export = np.packbits(int_codes.reshape(-1))

  # Serialize to memory first, then write the bytes through tf.gfile.
  output = io.BytesIO()
  np.savez_compressed(output, shape=int_codes.shape, codes=export)
  # The archive is binary, so open the destination in 'wb' mode.
  with tf.gfile.FastGFile(FLAGS.output_codes, 'wb') as code_file:
    code_file.write(output.getvalue())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号