def test_batch_randomized(self):
    """Exercise read_batch_examples with randomize_input=True.

    Builds the input pipeline inside a fresh graph and then checks, in
    order: the name of the returned value, the op-name -> op-type mapping
    handed to test_util.assert_ops_in_graph, the evaluated file-name
    constant against _FILE_NAMES, and the "capacity" attr of the
    shuffle-queue node against the requested queue capacity.
    """
    examples_per_batch = 17
    requested_capacity = 1234
    scope = "my_batch"

    # Node names looked up below; every one is derived from `scope`.
    filename_queue_op = "%s/file_name_queue" % scope
    filename_const_op = "%s/input" % filename_queue_op
    shuffle_queue_op = "%s/random_shuffle_queue" % scope
    expected_ops = {
        filename_const_op: "Const",
        filename_queue_op: "FIFOQueue",
        "%s/read/TFRecordReader" % scope: "TFRecordReader",
        shuffle_queue_op: "RandomShuffleQueue",
        scope: "QueueDequeueMany",
    }

    with tf.Graph().as_default() as graph:
        with self.test_session(graph=graph) as session:
            inputs = tf.contrib.learn.io.read_batch_examples(
                _VALID_FILE_PATTERN,
                examples_per_batch,
                reader=tf.TFRecordReader,
                randomize_input=True,
                queue_capacity=requested_capacity,
                name=scope)
            self.assertEqual("%s:1" % scope, inputs.name)

            op_nodes = test_util.assert_ops_in_graph(expected_ops, graph)

            # Only one tensor is fetched, so the result list has one entry.
            fetched = session.run(["%s:0" % filename_const_op])
            self.assertEqual(set(_FILE_NAMES), set(fetched[0]))

            actual_capacity = op_nodes[shuffle_queue_op].attr["capacity"].i
            self.assertEqual(requested_capacity, actual_capacity)
# Page residue from the source web page, not part of the test code:
# "Comment list" / "Table of contents".