support.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目：examples　作者：guildai　（项目源码 | 文件源码）
def data_inputs(data_dir, data_type, batch_size, runner_threads):
    """Build a batched ``(images, labels)`` pipeline from fixed-length records.

    Every record read from the files returned by ``input_filenames`` is split
    into a leading label segment of ``INPUT_LABEL_BYTES`` bytes followed by an
    image segment of ``INPUT_IMAGE_BYTES`` bytes. The image bytes are reshaped
    to ``[IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH]``, then transposed so depth
    becomes the last axis. They are then cast to float32 and standardized.
    The augmented variant is used when ``data_type == AUGMENTED_TRAINING_DATA``.

    This uses the TF1 queue-based input API, so the returned tensors only
    produce data once queue runners have been started for the session.

    Args:
        data_dir: Directory passed through to ``input_filenames``.
        data_type: Data-set selector passed to ``input_filenames``. It also
            chooses between augmented and plain standardization.
        batch_size: Number of examples per batch. The batching queue's
            capacity is ``10 * batch_size``.
        runner_threads: Number of enqueue threads used by ``tf.train.batch``.

    Returns:
        A tuple ``(images, labels)`` of batched tensors. ``labels`` is
        reshaped to ``[batch_size]``. NOTE(review): that reshape assumes
        ``INPUT_LABEL_BYTES == 1``; confirm against the module constants.
    """
    # A queue of input file names feeds a reader that yields one fixed-size
    # record per read.
    file_queue = tf.train.string_input_producer(
        input_filenames(data_dir, data_type))
    record_reader = tf.FixedLengthRecordReader(
        record_bytes=INPUT_RECORD_BYTES)

    # Convert raw record bytes to a uint8 vector; the record key is unused.
    _, raw_bytes = record_reader.read(file_queue)
    byte_vector = tf.decode_raw(raw_bytes, tf.uint8)

    # The leading bytes hold the label; the remaining bytes hold the image,
    # stored depth-first (DHW).
    label_tensor = tf.cast(
        tf.slice(byte_vector, [0], [INPUT_LABEL_BYTES]), tf.int32)
    image_bytes = tf.slice(
        byte_vector, [INPUT_LABEL_BYTES], [INPUT_IMAGE_BYTES])
    image_dhw = tf.reshape(
        image_bytes, [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])

    # Move depth to the last axis (DHW -> HWD) and convert to float.
    image_hwd_float = tf.cast(tf.transpose(image_dhw, [1, 2, 0]), tf.float32)

    # Only the augmented-training data set goes through augmentation.
    standardize = (augmented_standardized_image
                   if data_type == AUGMENTED_TRAINING_DATA
                   else standardized_image)
    prepared_image = standardize(image_hwd_float)

    # Group single examples into batches using background queue-runner
    # threads.
    image_batch, label_batch = tf.train.batch(
        [prepared_image, label_tensor],
        batch_size=batch_size,
        num_threads=runner_threads,
        capacity=10 * batch_size)
    label_vector = tf.reshape(label_batch, [batch_size])
    return image_batch, label_vector
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号