TF slice_input_producer not keeping tensors in sync

Posted on 2021-01-29 14:56:13

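I'm reading images into my TF network, but I also need the associated labels along with them.

So I tried to follow this answer, but the labels that come out don't actually match the images fetched in each batch.

My images are named in the format dir/3.jpg, so I simply extract the label from the image file name.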
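For reference, here is what the `rsplit` expression in the question's code actually yields, as a plain-Python sketch (no TensorFlow needed). It keeps the extension, so the "label" is the string `3.jpg`, not the number 3. The `label_to_int` helper is a hypothetical addition for the case where a numeric label is wanted.

```python
import os


def label_from_path(path):
    # Same expression as in the question: everything after the last "/".
    return path.rsplit("/", 1)[1]


def label_to_int(path):
    # Hypothetical helper: drop the directory and the extension, then parse.
    stem, _ext = os.path.splitext(os.path.basename(path))
    return int(stem)


filenames = ["dir/3.jpg", "dir/10.jpg"]
print([label_from_path(f) for f in filenames])  # ['3.jpg', '10.jpg']
print([label_to_int(f) for f in filenames])     # [3, 10]
```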

truth_filenames_np = ...
truth_filenames_tf = tf.convert_to_tensor(truth_filenames_np)

# get the labels
labels = [f.rsplit("/", 1)[1] for f in truth_filenames_np]

labels_tf = tf.convert_to_tensor(labels)

# *** This line should make sure both input tensors are synced (from my limited understanding)
# My list is also already shuffled, so I set shuffle=False
truth_image_name, truth_label = tf.train.slice_input_producer([truth_filenames_tf, labels_tf], shuffle=False)


truth_image_value = tf.read_file(truth_image_name)
truth_image = tf.image.decode_jpeg(truth_image_value)
truth_image.set_shape([IMAGE_DIM, IMAGE_DIM, 3])
truth_image = tf.cast(truth_image, tf.float32)
truth_image = truth_image/255.0

# Another key step, where I batch them together
truth_images_batch, truth_label_batch = tf.train.batch([truth_image, truth_label], batch_size=mb_size)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(epochs):
        print("Epoch ", i)
        X_truth_batch = truth_images_batch.eval()
        X_label_batch = truth_label_batch.eval()

        # Here I display all the images in this batch, and then I check which file numbers they actually are. 
        # BUT, the images that are displayed don't correspond with what is printed by X_label_batch!
        print(X_label_batch)
        plot_batch(X_truth_batch)



    coord.request_stop()
    coord.join(threads)

Am I doing something wrong, or does slice_input_producer not actually keep its input tensors in sync?

As an aside:

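I also noticed that when I get a batch from tf.train.batch, the elements of that batch are adjacent to each other in the original list I gave it, but the batches themselves do not come out in the original order. For example, if my data is ["dir/1.jpg", "dir/2.jpg", "dir/3.jpg", "dir/4.jpg", "dir/5.jpg", "dir/6.jpg"], I might get the batch (batch_size=2) ["dir/3.jpg", "dir/4.jpg"], then the batch ["dir/1.jpg", "dir/2.jpg"], and then the last one. Because the order doesn't match the batch order, it is hard to even use a FIFO queue for the labels.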
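A plain-Python sketch of why batch order need not matter, assuming each filename travels together with its label as one element: any reordering of whole batches then keeps the pairs aligned. The `make_batches` helper is hypothetical, and shuffling the batch list only imitates the out-of-order batches described above; it does not model how `tf.train.batch` schedules its threads.

```python
import random


def make_batches(pairs, batch_size):
    # Group consecutive (filename, label) pairs into fixed-size batches.
    return [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]


filenames = ["dir/%d.jpg" % j for j in range(1, 7)]
# Each element carries its filename and its label together.
pairs = [(f, f.rsplit("/", 1)[1]) for f in filenames]

batches = make_batches(pairs, batch_size=2)
random.shuffle(batches)  # imitate batches arriving out of order

for batch in batches:
    names, labels = zip(*batch)
    # Labels come from the same elements as the names, so they still match.
    assert all(n.rsplit("/", 1)[1] == l for n, l in zip(names, labels))
    print(names, labels)
```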

1 answer
  • 面试哥
    面试哥 2021-01-29

    Here is a complete runnable example that reproduces the problem:

    import tensorflow as tf
    
    truth_filenames_np = ['dir/%d.jpg' % j for j in range(66)]
    truth_filenames_tf = tf.convert_to_tensor(truth_filenames_np)
    # get the labels
    labels = [f.rsplit("/", 1)[1] for f in truth_filenames_np]
    labels_tf = tf.convert_to_tensor(labels)
    
    # My list is also already shuffled, so I set shuffle=False
    truth_image_name, truth_label = tf.train.slice_input_producer(
        [truth_filenames_tf, labels_tf], shuffle=False)
    
    # # Another key step, where I batch them together
    # truth_images_batch, truth_label_batch = tf.train.batch(
    #     [truth_image_name, truth_label], batch_size=11)
    
    epochs = 7
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(epochs):
            print("Epoch ", i)
            X_truth_batch = truth_image_name.eval()
            X_label_batch = truth_label.eval()
            # Here I display all the images in this batch, and then I check
            # which file numbers they actually are.
            # BUT, the images that are displayed don't correspond with what is
            # printed by X_label_batch!
            print(X_truth_batch)
            print(X_label_batch)
        coord.request_stop()
        coord.join(threads)
    

    This prints:

    Epoch  0
    b'dir/0.jpg'
    b'1.jpg'
    Epoch  1
    b'dir/2.jpg'
    b'3.jpg'
    Epoch  2
    b'dir/4.jpg'
    b'5.jpg'
    Epoch  3
    b'dir/6.jpg'
    b'7.jpg'
    Epoch  4
    b'dir/8.jpg'
    b'9.jpg'
    Epoch  5
    b'dir/10.jpg'
    b'11.jpg'
    Epoch  6
    b'dir/12.jpg'
    b'13.jpg'
    

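    So basically every eval call runs the op again! Each .eval() is a separate session run, and each run dequeues a new element from the producer's queue, so the filename and the label come from two consecutive elements. Adding batching makes no difference to this: it just prints batches (the first 11 filenames, then the next 11 labels, and so on).

    The workaround I see is: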
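The off-by-one pairing can be reproduced without TensorFlow. In this sketch a plain Python iterator stands in for the producer's queue (an illustrative assumption, not TensorFlow's implementation), and each `.eval()` is modelled as one dequeue, so evaluating the filename and the label in two separate calls takes them from two consecutive elements.

```python
def make_producer(filenames, labels):
    # Stand-in for slice_input_producer(shuffle=False): yields aligned
    # (filename, label) pairs in order, one pair per dequeue.
    return iter(zip(filenames, labels))


filenames = ["dir/%d.jpg" % j for j in range(6)]
labels = [f.rsplit("/", 1)[1] for f in filenames]
queue = make_producer(filenames, labels)

results = []
for epoch in range(3):
    name = next(queue)[0]   # like truth_image_name.eval(): one dequeue
    label = next(queue)[1]  # like truth_label.eval(): a second dequeue
    results.append((name, label))

print(results)
# [('dir/0.jpg', '1.jpg'), ('dir/2.jpg', '3.jpg'), ('dir/4.jpg', '5.jpg')]
```

This matches the mismatched output printed above: every filename is paired with the label of the next element.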

    for i in range(epochs):
        print("Epoch ", i)
        pair = tf.convert_to_tensor([truth_image_name, truth_label]).eval()
        print(pair[0])
        print(pair[1])
    

    This prints the correct pairs:

    Epoch  0
    b'dir/0.jpg'
    b'0.jpg'
    Epoch  1
    b'dir/1.jpg'
    b'1.jpg'
    # ...
    

    But that does nothing about the violation of the principle of least astonishment.

    Edit: another approach:

    import tensorflow as tf
    
    truth_filenames_np = ['dir/%d.jpg' % j for j in range(66)]
    truth_filenames_tf = tf.convert_to_tensor(truth_filenames_np)
    labels = [f.rsplit("/", 1)[1] for f in truth_filenames_np]
    labels_tf = tf.convert_to_tensor(labels)
    truth_image_name, truth_label = tf.train.slice_input_producer(
        [truth_filenames_tf, labels_tf], shuffle=False)
    epochs = 7
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.start_queue_runners(sess=sess)
        for i in range(epochs):
            print("Epoch ", i)
            X_truth_batch, X_label_batch = sess.run(
                [truth_image_name, truth_label])
            print(X_truth_batch)
            print(X_label_batch)
    

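    This is a better approach, because tf.convert_to_tensor and co. only accept tensors of the same type/shape etc. Fetching both tensors in a single sess.run call evaluates them in one step, so both values come from the same dequeue.

    Note that I removed the coordinator for simplicity, but that results in a warning: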
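A plain-Python model of the single-fetch fix, where an iterator again stands in for the producer's queue (an illustrative assumption, not TensorFlow internals). `sess.run([truth_image_name, truth_label])` corresponds to one dequeue whose filename and label are both returned, so the pair stays aligned.

```python
def make_producer(filenames, labels):
    # Stand-in for slice_input_producer(shuffle=False): yields aligned
    # (filename, label) pairs in order, one pair per dequeue.
    return iter(zip(filenames, labels))


filenames = ["dir/%d.jpg" % j for j in range(6)]
labels = [f.rsplit("/", 1)[1] for f in filenames]
queue = make_producer(filenames, labels)

results = []
for epoch in range(3):
    # Like sess.run([truth_image_name, truth_label]): one dequeue, both values.
    name, label = next(queue)
    results.append((name, label))

print(results)
# [('dir/0.jpg', '0.jpg'), ('dir/1.jpg', '1.jpg'), ('dir/2.jpg', '2.jpg')]
```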

    W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\queue_base.cc:294] _0_input_producer/input_producer/fraction_of_32_full/fraction_of_32_full: Skipping cancelled enqueue attempt with queue not closed

    See this.


