inputs.py source code

python

Project: num-seq-recognizer  Author: gmlove
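import tensorflow as tf  # TensorFlow 1.x: queue-runner based input pipeline (tf.train.*)

# resize_image(image, bbox, is_training, size, channels) is defined elsewhere in inputs.py.


def batches(data_file_path, max_number_length, batch_size, size,
            num_preprocess_threads=1, is_training=True, channels=1):
  """Read (image, length, label) examples from a TFRecord file and batch them.

  Returns a list of batched tensors: [images, length_one_hot, label_one_hot].
  """
  # Queue of input file names; the reader pulls serialized examples from it.
  filename_queue = tf.train.string_input_producer([data_file_path])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
    serialized_example,
    features={
      'image_png': tf.FixedLenFeature([], tf.string),                # PNG-encoded image bytes
      'label': tf.FixedLenFeature([max_number_length], tf.int64),    # one class id per digit position
      'length': tf.FixedLenFeature([1], tf.int64),                   # number of digits
      'bbox': tf.FixedLenFeature([4], tf.int64),                     # bounding box (4 integers)
    })
  image, bbox, label, length = (features['image_png'], features['bbox'],
                                features['label'], features['length'])
  bbox = tf.cast(bbox, tf.int32)
  dequeued_data = []
  # One preprocessing branch per thread; batch_join runs each branch in its own thread.
  for _ in range(num_preprocess_threads):
    decoded_img = tf.image.decode_png(image, channels=channels)
    decoded_img = resize_image(decoded_img, bbox, is_training, size, channels)
    dequeued_data.append([
      decoded_img,
      tf.one_hot(length - 1, max_number_length)[0],  # length n -> one-hot index n - 1
      tf.one_hot(label, 11),                          # shape [max_number_length, 11]
    ])

  return tf.train.batch_join(dequeued_data, batch_size=batch_size, capacity=batch_size * 3)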
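The two `tf.one_hot` calls decide what the training targets look like. The stdlib-only sketch below reproduces their shapes for a single example so the encoding can be checked without TensorFlow. It is an illustration, not code from the project: the helper names `one_hot` and `encode_targets` are hypothetical, and the filler value 10 for unused digit positions in the usage example is an assumption (the source only shows that there are 11 label classes).

```python
from typing import List, Tuple


def one_hot(index: int, depth: int) -> List[float]:
    """Mimic tf.one_hot for one scalar index: an out-of-range index gives all zeros."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def encode_targets(label: List[int], length: int, max_number_length: int,
                   num_classes: int = 11) -> Tuple[List[float], List[List[float]]]:
    """Hypothetical helper: build (length_one_hot, label_one_hot) like batches() does.

    `label` is assumed to be already padded to max_number_length entries.
    """
    if len(label) != max_number_length:
        raise ValueError("label must have exactly max_number_length entries")
    # tf.one_hot(length - 1, max_number_length)[0]: a number with n digits -> index n - 1
    length_one_hot = one_hot(length - 1, max_number_length)
    # tf.one_hot(label, 11): one row of num_classes entries per digit position
    label_one_hot = [one_hot(c, num_classes) for c in label]
    return length_one_hot, label_one_hot


# Usage: the two-digit number "42"; 10 as filler is an assumed padding class.
length_vec, label_mat = encode_targets([4, 2, 10, 10, 10], length=2, max_number_length=5)
print(length_vec)                         # [0.0, 1.0, 0.0, 0.0, 0.0]
print(len(label_mat), len(label_mat[0]))  # 5 11
```

Encoding the length as a one-hot over `max_number_length` lets the model treat "how many digits" as a classification problem alongside the per-position digit classifiers.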