def parse_serialized_examples_batch(serialized_examples_batch, batch_size):
    """Parse a batch of serialized OCR examples into image and label tensors.

    Args:
        serialized_examples_batch: Batch of serialized examples, passed
            straight to ``tf.parse_example``. Every example must provide the
            features ``image`` (raw bytes), ``height``, ``width``, ``label``
            and ``label_length``.
        batch_size: Number of examples in the batch. It is used as the
            leading dimension of every reshape below, so it must match the
            actual batch.

    Returns:
        A plain record object with these attributes:

        * ``heights`` / ``widths``: the ``height`` / ``width`` features cast
          to int32.
        * ``depth``: the constant 1.
        * ``uint8images``: uint8 tensor of shape
          ``[batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH]``.
        * ``uint8image``: ``uint8images`` with axes 1 and 2 swapped, i.e.
          ``[batch_size, IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH]``.
        * ``label_lengths``: int32 tensor of shape ``[batch_size]``.
        * ``labels``: the variable-length ``label`` feature cast to int32.
    """
    feature_to_tensor = {
        'image': tf.FixedLenFeature([], tf.string),
        'height': tf.FixedLenFeature([1], tf.int64),
        'width': tf.FixedLenFeature([1], tf.int64),
        'label': tf.VarLenFeature(tf.int64),
        'label_length': tf.FixedLenFeature([1], tf.int64),
    }
    features = tf.parse_example(serialized_examples_batch, feature_to_tensor)

    class ocrRecord(object):
        """Attribute bag holding the parsed tensors."""

    result = ocrRecord()
    result.heights = tf.cast(features['height'], tf.int32)
    result.widths = tf.cast(features['width'], tf.int32)
    # NOTE(review): hard-coded to 1 while the reshape below uses IMAGE_DEPTH;
    # confirm the two are meant to agree.
    result.depth = 1

    # decode_raw works on the whole batch of byte strings at once (all strings
    # must have the same length), so no per-example map_fn loop is needed and
    # the result is already uint8.
    raw_pixels = tf.decode_raw(features['image'], tf.uint8)

    # The per-example height/width features are not used for the layout:
    # every image must contain exactly IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_DEPTH
    # bytes, otherwise this reshape fails at run time. Reshaping to a fully
    # static shape also fixes the static shape, so no set_shape call is needed.
    result.uint8images = tf.reshape(
        raw_pixels, [batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH])

    # Flatten the [batch_size, 1] label_length feature to [batch_size].
    result.label_lengths = tf.reshape(
        tf.cast(features['label_length'], tf.int32), [batch_size])

    result.labels = tf.cast(features['label'], tf.int32)

    # Convert for timestep input: swap the height and width axes.
    # NOTE(review): stored as `uint8image` (singular), separate from
    # `uint8images`; kept as-is because callers may read either attribute.
    result.uint8image = tf.transpose(result.uint8images, [0, 2, 1, 3])
    return result
# Comment list
# Table of contents