def batches(data_file_path, max_number_length, batch_size, size,
            num_preprocess_threads=1, is_training=True, channels=1,
            num_label_classes=11):
    """Build a TF1 queue-based input pipeline that yields batched examples.

    Reads serialized examples from the TFRecord file ``data_file_path``,
    decodes the PNG image and preprocesses it with ``resize_image``. It then
    one-hot encodes the ``length`` and ``label`` features and batches
    everything with ``tf.train.batch_join``.

    This uses the TF1 queue-runner API (``tf.train.string_input_producer``,
    ``tf.TFRecordReader``, ``tf.train.batch_join``). The caller must start
    the queue runners (e.g. ``tf.train.start_queue_runners``) before
    evaluating the returned tensors.

    Args:
        data_file_path: Path to the TFRecord file to read.
        max_number_length: Length of the ``label`` feature. Also the depth
            of the one-hot encoding of ``length``.
        batch_size: Number of examples per batch.
        size: Forwarded to ``resize_image``; its meaning is defined there.
        num_preprocess_threads: Number of decode/preprocess branches.
            ``batch_join`` starts one enqueue thread per branch. Must be >= 1.
        is_training: Forwarded to ``resize_image``.
        channels: Channel count passed to ``tf.image.decode_png`` and to
            ``resize_image``.
        num_label_classes: Depth of the one-hot encoding applied to each
            entry of ``label``. Defaults to 11, the value previously
            hard-coded here.

    Returns:
        A list ``[images, lengths, labels]`` of batched tensors, each with a
        leading ``batch_size`` dimension:
        - ``lengths`` has shape ``[batch_size, max_number_length]``.
        - ``labels`` has shape
          ``[batch_size, max_number_length, num_label_classes]``.
        - ``images`` has whatever per-example shape ``resize_image``
          produces.

    Raises:
        ValueError: If ``num_preprocess_threads`` is less than 1.
    """
    if num_preprocess_threads < 1:
        raise ValueError(
            'num_preprocess_threads must be >= 1, got %r'
            % (num_preprocess_threads,))

    # No num_epochs is given, so the producer does not stop after a fixed
    # number of passes over the file.
    filename_queue = tf.train.string_input_producer([data_file_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_png': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([max_number_length], tf.int64),
            'length': tf.FixedLenFeature([1], tf.int64),
            'bbox': tf.FixedLenFeature([4], tf.int64),
        })
    image = features['image_png']
    bbox = tf.cast(features['bbox'], tf.int32)
    label = features['label']
    length = features['length']

    dequeued_data = []
    for _ in range(num_preprocess_threads):
        # One list of tensors per branch; batch_join runs one enqueue thread
        # for each list in dequeued_data.
        dequeued_img = tf.image.decode_png(image, channels)
        dequeued_img = resize_image(dequeued_img, bbox, is_training, size,
                                    channels)
        # `length` has shape [1]. one_hot yields [1, max_number_length], and
        # [0] drops the leading axis. The -1 suggests stored lengths are
        # 1-based (TODO confirm against the record writer). A length of 0
        # would give index -1, i.e. an all-zero row.
        length_onehot = tf.one_hot(length - 1, max_number_length)[0]
        # `label` has shape [max_number_length], so the result has shape
        # [max_number_length, num_label_classes].
        label_onehot = tf.one_hot(label, num_label_classes)
        dequeued_data.append([dequeued_img, length_onehot, label_onehot])

    # Queue capacity of 3x batch_size is kept from the original code.
    return tf.train.batch_join(dequeued_data, batch_size=batch_size,
                               capacity=batch_size * 3)
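# 评论列表 ("Comment list") / 文章目录 ("Table of contents"): apparently
# navigation labels left over from the web page this code was copied from.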