def test_batch_text_lines(self):
    """Reads a 5-line text file in batches of 3 for exactly one epoch.

    With randomize_input=False and num_epochs=1, the first batch is lines
    A-C, the second is the remaining partial batch D-E, and a further run
    raises OutOfRangeError because the input is exhausted.
    """
    # Restore the original Glob before creating the temp file. Presumably
    # setUp swaps gfile.Glob for a stub; confirm against the fixture.
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")
    batch_size = 3
    queue_capacity = 10
    name = "my_batch"
    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
        inputs = tf.contrib.learn.io.read_batch_examples(
            [filename], batch_size, reader=tf.TextLineReader,
            randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
            read_batch_size=10, name=name)
        # Presumably the num_epochs counter is a local variable, so it has to
        # be initialized before the queue runners start.
        session.run(tf.initialize_local_variables())
        coord = tf.train.Coordinator()
        # Keep the runner threads so they can be joined below.
        threads = tf.train.start_queue_runners(session, coord=coord)
        try:
            self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
            # Only two lines remain, so the final batch is partial.
            self.assertAllEqual(session.run(inputs), [b"D", b"E"])
            with self.assertRaises(errors.OutOfRangeError):
                session.run(inputs)
        finally:
            # Always stop and join the queue-runner threads, even if an
            # assertion failed, so they do not outlive this session/test.
            coord.request_stop()
            coord.join(threads)
# (Scraped web-page residue, not part of the test: "Comment list", "Table of contents".)