Tensorflow读取带有标签的图像

发布于 2021-01-29 14:55:35

我正在使用Tensorflow构建标准的图像分类模型。为此,我有输入图像,每个图像都分配了一个标签({0,1}中的数字)。因此,可以使用以下格式将数据存储在列表中:

/path/to/image_0 label_0
/path/to/image_1 label_1
/path/to/image_2 label_2
...

我想使用TensorFlow的排队系统读取我的数据并将其输入到我的模型中。忽略标签,可以使用string_input_producer和轻松实现wholeFileReader。这里的代码:

def read_my_file_format(filename_queue):
  reader = tf.WholeFileReader()
  key, value = reader.read(filename_queue)
  example = tf.image.decode_png(value)
  return example

#removing label, obtaining list containing /path/to/image_x
image_list = [line[:-2] for line in image_label_list]

input_queue = tf.train.string_input_producer(image_list)                                                     
input_images = read_my_file_format(input_queue)

但是,在该过程中标签丢失了,因为图像数据作为输入管道的一部分被有意地改组了。通过输入队列将标签和图像数据一起推入的最简单方法是什么?

关注者
0
被浏览
116
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    使用slice_input_producer提供了一种更清洁的解决方案。切片输入生产者允许我们创建一个包含任意多个可分离值的输入队列。这个问题的片段如下所示:

    def read_labeled_image_list(image_list_file):
        """Reads a .txt file containing pathes and labeles
        Args:
           image_list_file: a .txt file with one /path/to/image per line
           label: optionally, if set label will be pasted after each line
        Returns:
           List with all filenames in file image_list_file
        """
        f = open(image_list_file, 'r')
        filenames = []
        labels = []
        for line in f:
            filename, label = line[:-1].split(' ')
            filenames.append(filename)
            labels.append(int(label))
        return filenames, labels
    
    def read_images_from_disk(input_queue):
        """Consumes a single filename and label as a ' '-delimited string.
        Args:
          filename_and_label_tensor: A scalar string tensor.
        Returns:
          Two tensors: the decoded image, and the string label.
        """
        label = input_queue[1]
        file_contents = tf.read_file(input_queue[0])
        example = tf.image.decode_png(file_contents, channels=3)
        return example, label
    
    # Reads pfathes of images together with their labels
    image_list, label_list = read_labeled_image_list(filename)
    
    images = ops.convert_to_tensor(image_list, dtype=dtypes.string)
    labels = ops.convert_to_tensor(label_list, dtype=dtypes.int32)
    
    # Makes an input queue
    input_queue = tf.train.slice_input_producer([images, labels],
                                                num_epochs=num_epochs,
                                                shuffle=True)
    
    image, label = read_images_from_disk(input_queue)
    
    # Optional Preprocessing or Data Augmentation
    # tf.image implements most of the standard image augmentation
    image = preprocess_image(image)
    label = preprocess_label(label)
    
    # Optional Image and Label Batching
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size)
    

    又见generic_input_producerTensorVision全输入管道的例子。



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看