def test_postprocess():
    """
    Check that a string-encoded attribute can be read back as integers.

    This test uses the data in tftestdata2/ to illustrate how to read out
    something that has been written as a string but is "really" an integer.
    The data in tftestdata2/ids is just a single attribute, namely "ids",
    written out as a string but actually it really represents integers.

    The provider is given a ``postprocess`` entry for "ids" built around
    ``tf.string_to_number`` with ``tf.int32``; the test then dequeues N
    batches of K records and compares each batch with the expected
    sequence, which counts 0..159 and wraps around.
    """
    source_paths = [os.path.join(dir_path, 'tftestdata2/ids')]
    # Each entry looks like (function, positional args, keyword args);
    # presumably the provider calls function(value, *args, **kwargs) on the
    # parsed attribute -- confirm against the provider implementation.
    postprocess = {'ids': [(tf.string_to_number, (tf.int32, ), {})]}
    dp = d.TFRecordsParallelByFileProvider(source_paths,
                                           n_threads=1,
                                           batch_size=20,
                                           shuffle=False,
                                           postprocess=postprocess)
    sess = tf.Session()
    # The coordinator lets us stop and join the queue-runner threads when the
    # test finishes (or fails) instead of leaving them running.
    coord = tf.train.Coordinator()
    threads = []
    try:
        ops = dp.init_ops()
        queue = b.get_queue(ops[0], queue_type='fifo')
        enqueue_ops = [queue.enqueue_many(op) for op in ops]
        tf.train.queue_runner.add_queue_runner(
            tf.train.queue_runner.QueueRunner(queue, enqueue_ops))
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        K = 20  # records dequeued per step
        inputs = queue.dequeue_many(K)
        N = 100  # number of dequeue steps to verify
        # Expected ids: a running counter modulo 160.
        testlist = np.arange(K * N) % 160
        for i in range(N):
            print('%d of %d' % (i, N))
            res = sess.run(inputs)
            assert_equal(res['ids'], testlist[K * i: K * (i+1)])
    finally:
        # Always release the background threads and the session, even when
        # an assertion above fails.
        try:
            coord.request_stop()
            coord.join(threads)
        finally:
            sess.close()
评论列表
文章目录