hsr.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:hsr 作者: pyk 项目源码 文件源码
def read_images(data_dir):
    """Build the queue-based input pipeline producing shuffled training batches.

    Every ``*.png`` file directly inside ``data_dir`` is read, decoded to a
    3-channel image, randomly flipped left/right, randomly brightness-shifted
    and standardized per image. The label is parsed from the first character
    of the file's base name: that digit minus one is used as the one-hot
    index over ``NUM_CLASS`` classes.

    Args:
        data_dir: Directory containing the PNG files. May be a nested,
            relative or absolute path using ``/`` as separator.

    Returns:
        A tuple ``(example_batch, label_batch)`` where ``example_batch`` has
        shape ``(BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS)`` and
        ``label_batch`` has shape ``(BATCH_SIZE, NUM_CLASS)`` with float32
        one-hot rows.

    NOTE(review): ``match_filenames_once`` and ``num_epochs`` presumably rely
    on local variables and queue runners being initialized/started by the
    caller -- confirm against the training loop.
    """
    pattern = os.path.join(data_dir, '*.png')
    filenames = tf.train.match_filenames_once(pattern, name='list_files')

    queue = tf.train.string_input_producer(
        filenames,
        num_epochs=NUM_EPOCHS,
        shuffle=True,
        name='queue')

    reader = tf.WholeFileReader()
    filename, content = reader.read(queue, name='read_image')
    # Log every file as it is loaded.
    filename = tf.Print(
        filename,
        data=[filename],
        message='loading: ')

    # The label digit lives in the file's base name, so take the LAST path
    # component. Indexing a fixed position (e.g. 1) only works when data_dir
    # is exactly one path component deep.
    path_parts = tf.string_split([filename], delimiter='/')
    basename = path_parts.values[-1]
    label_id = tf.string_to_number(
        tf.substr(basename, 0, 1), out_type=tf.int32)
    # Depth must match the label shape declared for the batch queue below,
    # hence NUM_CLASS rather than a literal.
    label = tf.one_hot(
        label_id - 1,
        NUM_CLASS,
        on_value=1.0,
        off_value=0.0,
        dtype=tf.float32)

    img_tensor = tf.image.decode_png(
        content,
        dtype=tf.uint8,
        channels=3,
        name='img_decode')

    # Preprocess the image with random transformations.
    # Random horizontal flip
    img_tensor_flip = tf.image.random_flip_left_right(img_tensor)

    # Random brightness
    img_tensor_bri = tf.image.random_brightness(
        img_tensor_flip, max_delta=0.2)

    # Per-image standardization
    img_tensor_std = tf.image.per_image_standardization(img_tensor_bri)

    # Queue sizing: keep at least `min_after_dequeue` elements buffered for
    # shuffling, plus headroom of three batches.
    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * BATCH_SIZE
    example_batch, label_batch = tf.train.shuffle_batch(
        [img_tensor_std, label],
        batch_size=BATCH_SIZE,
        shapes=[(IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS), (NUM_CLASS,)],
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        name='train_shuffle')

    return example_batch, label_batch

# `images` is a 4-D tensor with the shape:
# [n_batch, img_height, img_width, n_channel]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号