def _build(self):
    """Assemble the queue-based input pipeline for the configured split.

    Looks for ``<dataset_dir>/<split>.tfrecords``, reads serialized records
    from it with a ``tf.TFRecordReader``, hands each raw record to
    ``self.read_record`` and enqueues the resulting values into either a
    ``tf.RandomShuffleQueue`` (when ``self._random_shuffle`` is truthy) or a
    ``tf.FIFOQueue``. The ``tf.train.QueueRunner`` driving the enqueue ops is
    stored on ``self.queue_runner`` and registered through
    ``tf.train.add_queue_runner``.

    Returns:
        The result of calling ``dequeue()`` on the queue built here.

    Raises:
        InvalidDataDirectory: If the split's ``.tfrecords`` file does not
            exist.
    """
    # Locate the TFRecords file backing the requested split.
    dataset_dir = self._dataset_dir
    tfrecords_name = '{}.tfrecords'.format(self._split)
    tfrecords_path = os.path.join(dataset_dir, tfrecords_name)
    if not tf.gfile.Exists(tfrecords_path):
        raise InvalidDataDirectory(
            '"{}" does not exist.'.format(tfrecords_path)
        )

    # The input producer accepts any number of files; only one is used here.
    file_queue = tf.train.string_input_producer(
        [tfrecords_path],
        num_epochs=self._num_epochs,
        seed=self._seed
    )

    # Read one serialized record at a time and let read_record decode it.
    _, serialized = tf.TFRecordReader().read(file_queue)
    values, dtypes, names = self.read_record(serialized)

    # Arguments common to both queue flavours.
    queue_kwargs = {'capacity': 100, 'dtypes': dtypes, 'names': names}
    if self._random_shuffle:
        queue = tf.RandomShuffleQueue(
            min_after_dequeue=0,
            name='tfrecord_random_queue',
            seed=self._seed,
            **queue_kwargs
        )
    else:
        queue = tf.FIFOQueue(name='tfrecord_fifo_queue', **queue_kwargs)

    # The same enqueue op is listed once per requested queue op; the
    # QueueRunner receives the whole list.
    enqueue_op = queue.enqueue(values)
    runner = tf.train.QueueRunner(
        queue, [enqueue_op] * self._total_queue_ops
    )
    self.queue_runner = runner
    tf.train.add_queue_runner(runner)

    return queue.dequeue()
评论列表
文章目录