loader.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目: GestureRecognition  作者: gkchai  项目源码  文件源码
def load_batch(dataset, batch_size, is_2D = False, preprocess_fn=None, shuffle=False):
    """Build input-pipeline tensors that yield batches of series and labels.

    Args:
        dataset: slim dataset object, as created by the get_split function.
        batch_size: number of examples per batch. The provider's common queue
            is sized from it (capacity 2 * batch_size, minimum batch_size).
        is_2D: when True, the 'series/x', 'series/y' and 'series/z' items are
            read separately, stacked along a new leading axis of size 3, and
            given a trailing axis of size 1. When False, the single 'series'
            item is read as-is.
        preprocess_fn: optional callable applied to the raw series tensor
            before batching (e.g. different preprocessing for train vs. eval).
        shuffle: forwarded to DatasetDataProvider.

    Returns:
        Tuple ``(series_batch, labels, labels_one_hot)`` produced by
        ``tf.train.batch``. Labels are int32; the one-hot labels are int32
        with depth ``dataset.num_classes``.
    """
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=shuffle,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size,
        num_epochs=None,  # no epoch limit: the provider keeps cycling the data
    )

    # Read the series tensor (not an image) together with its label.
    if not is_2D:
        raw_series, label = provider.get(['series', 'label'])
    else:
        x_part, y_part, z_part, label = provider.get(
            ['series/x', 'series/y', 'series/z', 'label'])
        # Stack the three parts into one tensor, then append a size-1 axis.
        raw_series = tf.expand_dims(tf.stack([x_part, y_part, z_part]), -1)

    # Labels are cast to int32 (the original note says they arrive as int64).
    label = tf.to_int32(label)
    label_one_hot = tf.to_int32(slim.one_hot_encoding(label, dataset.num_classes))

    series = preprocess_fn(raw_series) if preprocess_fn else raw_series

    # tf.train.batch enqueues the single-example tensors into an internal
    # queue and dequeues batch_size of them at a time.
    batched = tf.train.batch(
        [series, label, label_one_hot],
        batch_size=batch_size,
        allow_smaller_final_batch=True,
        num_threads=1,
    )
    series_batch, labels, labels_one_hot = batched
    return series_batch, labels, labels_one_hot
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号