def load_batch(dataset, batch_size, is_2D=False, preprocess_fn=None, shuffle=False,
               num_threads=1, capacity=32):
    """Build queue-backed batch tensors from a slim dataset.

    Reads single examples through a slim ``DatasetDataProvider``, optionally
    preprocesses each series, and groups them into batches with
    ``tf.train.batch``.

    Args:
        dataset: slim dataset object; must expose ``num_classes``. Presumably
            created by the project's ``get_split`` function (verify against
            callers).
        batch_size: number of examples per batch. Also sizes the provider's
            common queue (capacity ``2 * batch_size``, minimum ``batch_size``).
        is_2D: if True, read the items ``'series/x'``, ``'series/y'`` and
            ``'series/z'``, stack them along a new leading axis and append a
            trailing axis of size 1. Otherwise read the single item
            ``'series'`` unchanged.
        preprocess_fn: optional callable applied to the raw series tensor.
            Choosing training- vs. evaluation-time preprocessing is up to the
            caller.
        shuffle: whether the data provider reads its input in shuffled order.
        num_threads: number of enqueuing threads used by ``tf.train.batch``.
        capacity: maximum number of examples held in the batching queue.
            The default 32 equals ``tf.train.batch``'s own default.

    Returns:
        Tuple ``(series_batch, labels, labels_one_hot)``:
        - the batched series;
        - the int32 labels;
        - int32 one-hot labels of depth ``dataset.num_classes``.
    """
    # num_epochs=None: slim's provider cycles through the data indefinitely.
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=shuffle,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size,
        num_epochs=None,
    )

    if is_2D:
        x, y, z, label = data_provider.get(
            ['series/x', 'series/y', 'series/z', 'label'])
        # tf.stack requires x, y and z to share one shape (not checked here).
        # The trailing expand_dims adds a size-1 channel axis.
        raw_series = tf.expand_dims(tf.stack([x, y, z]), -1)
    else:
        raw_series, label = data_provider.get(['series', 'label'])

    # Cast labels to int32. The stored dtype is presumably int64
    # (TFRecord Int64List) — TODO confirm against the dataset decoder.
    label = tf.to_int32(label)
    label_one_hot = tf.to_int32(
        slim.one_hot_encoding(label, dataset.num_classes))

    series = preprocess_fn(raw_series) if preprocess_fn is not None else raw_series

    # tf.train.batch enqueues single examples into a FIFO queue and dequeues
    # batch_size of them at a time.
    # allow_smaller_final_batch=True makes the static batch dimension of every
    # output None. With num_epochs=None above, a short final batch should not
    # actually occur.
    series_batch, labels, labels_one_hot = tf.train.batch(
        [series, label, label_one_hot],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        allow_smaller_final_batch=True,
    )
    return series_batch, labels, labels_one_hot
# (stray text from the scraped web page: "Comment list", "Table of contents")