ocr_input.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:tf-cnn-lstm-ocr-captcha 作者: Luonic 项目源码 文件源码
def parse_serialized_examples_batch(serialized_examples_batch, batch_size):
    """Parse a batch of serialized OCR ``Example`` protos into tensors.

    Each example must carry the features declared in ``feature_to_tensor``
    below:
    - a raw-bytes ``image``;
    - its ``height`` and ``width``;
    - a variable-length ``label`` id sequence;
    - the ``label_length`` of that sequence.

    Args:
        serialized_examples_batch: 1-D string tensor of serialized examples,
            passed straight to ``tf.parse_example``.
        batch_size: Python int, the number of examples in the batch. It is
            used as a static dimension in ``tf.reshape``.

    Returns:
        A plain record object with these attributes:
            heights, widths: int32 tensors of shape [batch, 1].
            depth: the Python int 1.
            uint8images: uint8 tensor
                [batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH].
            uint8image: ``uint8images`` transposed to
                [batch_size, IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH].
            label_lengths: int32 tensor of shape [batch_size].
            labels: int32 SparseTensor of label ids.

    The image reshape fails at run time if the decoded byte count does not
    equal batch_size * IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_DEPTH.
    """
    feature_to_tensor = {
        'image': tf.FixedLenFeature([], tf.string),
        'height': tf.FixedLenFeature([1], tf.int64),
        'width': tf.FixedLenFeature([1], tf.int64),
        'label': tf.VarLenFeature(tf.int64),
        'label_length': tf.FixedLenFeature([1], tf.int64)
    }
    features = tf.parse_example(serialized_examples_batch, feature_to_tensor)

    class ocrRecord(object):
        """Plain attribute container for the parsed tensors."""
        pass

    result = ocrRecord()

    result.heights = tf.cast(features['height'], tf.int32)
    result.widths = tf.cast(features['width'], tf.int32)
    # NOTE(review): hard-coded to 1, while the images below are reshaped with
    # IMAGE_DEPTH channels. These disagree unless IMAGE_DEPTH == 1; confirm.
    result.depth = 1

    # decode_raw takes the whole [batch] string vector at once and returns a
    # [batch, num_bytes] uint8 tensor. No per-element map_fn is needed, and
    # no uint8 -> uint8 cast either. All strings in the batch must have the
    # same length.
    imgs_1d = tf.decode_raw(features['image'], tf.uint8)

    # Reshaping to a fully static Python-int shape already fixes the static
    # shape, so no follow-up set_shape call is needed.
    result.uint8images = tf.reshape(
        imgs_1d, [batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH])

    # The FixedLenFeature([1]) gives shape [batch, 1]; flatten it to [batch_size].
    result.label_lengths = tf.reshape(
        tf.cast(features['label_length'], tf.int32), [batch_size])

    # 'label' is a VarLenFeature, so this is a SparseTensor.
    # tf.cast keeps it sparse.
    result.labels = tf.cast(features['label'], tf.int32)

    # Convert for timestep input: swap the height and width axes, so the
    # width axis comes right after the batch axis.
    # NOTE(review): the result is stored as `uint8image` (singular), while
    # `uint8images` keeps the untransposed layout. The name may be a typo,
    # but both attributes are kept as-is for existing callers. Confirm which
    # one consumers read.
    result.uint8image = tf.transpose(result.uint8images, [0, 2, 1, 3])
    return result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号