data.py source code

python

Project: tensorsandbox  Author: kaizouman
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.train.shuffle_batch, tf.placeholder_with_default)

# Alias for the flags object, assumed to match the module-level definition in
# data.py (data_aug and batch_size are flags defined elsewhere in the project).
# get_raw_input_data, distort_image and NUM_THREADS are also module-level
# definitions in data.py that this excerpt does not show.
FLAGS = tf.app.flags.FLAGS


def train_inputs(data_dir):
    """Construct input for CIFAR training.

    Note that batch_size is a placeholder whose default value is the one
    specified during training. It can, however, be set to a different value
    at inference time by passing it explicitly in the feed dict when
    sess.run is called.

    Args:
        data_dir: Path to the CIFAR-10 data directory.

    Returns:
        images: Images. 4D tensor [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
        labels: Labels. 1D tensor [batch_size].
    """

    # Read one raw (image, label) training example; the helper also
    # transposes the image dimensions
    raw_image, label = get_raw_input_data(False, data_dir)

    # If enabled, perform data augmentation
    if FLAGS.data_aug:
        image = distort_image(raw_image)
    else:
        image = raw_image

    # Normalize the image (subtract the mean and divide by the standard deviation)
    float_image = tf.image.per_image_standardization(image)

    # Use a shuffling queue to assemble batches of samples
    batch_size_tensor = tf.placeholder_with_default(FLAGS.batch_size, shape=[])
    images, labels = tf.train.shuffle_batch(
        [float_image, label],
        batch_size=batch_size_tensor,
        num_threads=NUM_THREADS,
        capacity=20000 + 3 * FLAGS.batch_size,
        min_after_dequeue=20000)

    # Display the training images in TensorBoard
    tf.summary.image('images', images)

    # Flatten the labels to a 1D tensor of shape [batch_size]
    return images, tf.reshape(labels, [-1])
```
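The normalization step in train_inputs relies on tf.image.per_image_standardization, which TensorFlow documents as computing (x - mean) / adjusted_stddev, where adjusted_stddev = max(stddev, 1 / sqrt(N)) and N is the number of elements in the image. The NumPy sketch below restates that formula so its effect is easy to check; `per_image_standardization_np` is a hypothetical name, and this is a minimal illustration, not TensorFlow's own implementation.

```python
import numpy as np


def per_image_standardization_np(image):
    """NumPy sketch of the formula documented for tf.image.per_image_standardization.

    Computes (x - mean) / max(stddev, 1 / sqrt(N)), where mean and stddev are
    taken over all N elements of the image (height * width * channels).
    """
    x = np.asarray(image, dtype=np.float32)
    num_elements = x.size
    mean = x.mean()
    # Lower-bound the standard deviation so a uniform image does not divide by zero.
    adjusted_stddev = max(float(x.std()), 1.0 / np.sqrt(num_elements))
    return (x - mean) / adjusted_stddev


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # A random 24x24 RGB image with uint8 pixel values, as a stand-in input.
    img = rng.integers(0, 256, size=(24, 24, 3)).astype(np.uint8)
    out = per_image_standardization_np(img)
    print(out.shape, float(out.mean()), float(out.std()))
```

The result has roughly zero mean and unit standard deviation per image, and the lower bound on the standard deviation means a perfectly uniform image maps to all zeros instead of triggering a division by zero.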
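tf.train.shuffle_batch keeps up to `capacity` examples in a queue and only dequeues a batch when at least `min_after_dequeue` examples would remain, picking the batch elements at random; the large value of 20000 used in train_inputs therefore trades memory and start-up time for better mixing of the training set. The pure-Python sketch below illustrates that buffering idea under simplifying assumptions (one producer, a finite input stream, a buffer that is drained at the end). `shuffle_batches` and `_pop_random` are hypothetical helpers, not TensorFlow's multi-threaded queue implementation.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def _pop_random(buffer: List[T], rng: random.Random) -> T:
    """Remove and return a uniformly random element in O(1) via swap-with-last."""
    i = rng.randrange(len(buffer))
    buffer[i], buffer[-1] = buffer[-1], buffer[i]
    return buffer.pop()


def shuffle_batches(examples: Iterable[T], batch_size: int,
                    min_after_dequeue: int, seed: int = 0) -> Iterator[List[T]]:
    """Yield random batches from a buffer that keeps >= min_after_dequeue items.

    Conceptual sketch of shuffle-queue batching; the buffer never holds more
    than min_after_dequeue + batch_size items, loosely playing the role of
    the queue capacity.
    """
    rng = random.Random(seed)
    buffer: List[T] = []
    for example in examples:
        buffer.append(example)
        # Only dequeue once a full batch can be removed while leaving at least
        # min_after_dequeue examples behind to mix with future arrivals.
        if len(buffer) >= min_after_dequeue + batch_size:
            yield [_pop_random(buffer, rng) for _ in range(batch_size)]
    # Input exhausted: drain what is left (the last batch may be smaller).
    while buffer:
        yield [_pop_random(buffer, rng) for _ in range(min(batch_size, len(buffer)))]
```

For example, `list(shuffle_batches(range(10), batch_size=3, min_after_dequeue=4))` returns four batches of sizes 3, 3, 3 and 1, and the first batch can only contain values from 0 to 6 because nothing is dequeued until seven examples have been buffered. In the same spirit, the `capacity = 20000 + 3 * FLAGS.batch_size` setting leaves room for a few batches of prefetched examples above the `min_after_dequeue` floor.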