readfromtfrecords_batch_train.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:SSD_tensorflow_VOC 作者: LevinJ 项目源码 文件源码
def __get_images_labels(self):
        """Build the training input pipeline and return one prefetched batch.

        Reads ('image', 'label') items for ``self.dataset_name`` /
        ``self.dataset_split_name`` from ``self.dataset_dir``. It then shifts
        labels down by ``self.labels_offset`` and applies the training-mode
        preprocessing chosen by ``self.preprocessing_name`` (falling back to
        ``self.model_name``). Finally it batches the examples, one-hot encodes
        the labels and routes the batch through a slim prefetch queue.

        Returns:
            ``(images, labels)`` as dequeued from the prefetch queue. ``labels``
            is one-hot over ``dataset.num_classes - self.labels_offset`` classes.
        """
        batch_size = self.batch_size
        offset = self.labels_offset

        dataset = dataset_factory.get_dataset(
                self.dataset_name, self.dataset_split_name, self.dataset_dir)
        # Class count as seen by the model once the label offset is removed;
        # used both for the network factory and for the one-hot depth.
        num_classes = dataset.num_classes - offset

        # Shared reader queue is sized relative to the batch size.
        provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=self.num_readers,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
        image, label = provider.get(['image', 'label'])
        label -= offset

        # The network function is only consulted here for its
        # default_image_size, used when no explicit training size is set.
        network_fn = nets_factory.get_network_fn(
                self.model_name,
                num_classes=num_classes,
                weight_decay=self.weight_decay,
                is_training=True)
        if self.train_image_size:
            image_size = self.train_image_size
        else:
            image_size = network_fn.default_image_size

        if self.preprocessing_name:
            preprocessing_name = self.preprocessing_name
        else:
            preprocessing_name = self.model_name
        preprocess = preprocessing_factory.get_preprocessing(
                preprocessing_name,
                is_training=True)
        # Same value for height and width.
        image = preprocess(image, image_size, image_size)

        batched_images, batched_labels = tf.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=self.num_preprocessing_threads,
                capacity=5 * batch_size)
        one_hot_labels = slim.one_hot_encoding(batched_labels, num_classes)

        # Stage whole batches in a small (capacity 2) prefetch queue before
        # handing them to the caller.
        prefetcher = slim.prefetch_queue.prefetch_queue(
                [batched_images, one_hot_labels], capacity=2)
        images, labels = prefetcher.dequeue()
        return images, labels
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号