def __get_images_labels(self):
    """Build the batched training input pipeline and return its tensors.

    Reads the ``'image'`` and ``'label'`` items from the configured dataset
    split, subtracts ``self.labels_offset`` from the label, applies the
    selected preprocessing function in training mode, batches the pairs with
    ``tf.train.batch``, one-hot encodes the labels, and finally routes the
    batch through a prefetch queue of capacity 2.

    Returns:
        An ``(images, labels)`` pair dequeued from the prefetch queue, where
        ``labels`` is one-hot encoded over
        ``dataset.num_classes - self.labels_offset`` classes.
    """
    per_batch = self.batch_size
    dataset = dataset_factory.get_dataset(
        self.dataset_name, self.dataset_split_name, self.dataset_dir)

    # Shared reader queue bounds scale with the batch size
    # (capacity 20x, minimum 10x).
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=self.num_readers,
        common_queue_capacity=20 * per_batch,
        common_queue_min=10 * per_batch)
    image, label = provider.get(['image', 'label'])
    label -= self.labels_offset

    # Number of classes once the label offset is removed; reused below
    # for the one-hot encoding.
    class_count = dataset.num_classes - self.labels_offset

    # Within this method the network function is only consulted for its
    # default input size (used when no explicit train_image_size is set).
    network_fn = nets_factory.get_network_fn(
        self.model_name,
        num_classes=class_count,
        weight_decay=self.weight_decay,
        is_training=True)
    side = self.train_image_size or network_fn.default_image_size

    # Fall back to the model name when no explicit preprocessing is given.
    preprocess = preprocessing_factory.get_preprocessing(
        self.preprocessing_name or self.model_name,
        is_training=True)
    image = preprocess(image, side, side)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=per_batch,
        num_threads=self.num_preprocessing_threads,
        capacity=5 * per_batch)
    labels = slim.one_hot_encoding(labels, class_count)

    prefetcher = slim.prefetch_queue.prefetch_queue(
        [images, labels], capacity=2)
    images, labels = prefetcher.dequeue()
    return images, labels
readfromtfrecords_batch_train.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录