def __init__(self, fnames, shuffle=True, num_epochs=None):
    """Init from a list of filenames to enqueue.

    Builds the filename queue and TFRecord reader used by the input
    pipeline, then reads one record in a throwaway graph to discover the
    per-feature shapes and dtypes.

    Args:
        fnames: list of .tfrecords filenames to enqueue.
        shuffle: if true, shuffle the list at each epoch.
        num_epochs: number of passes over `fnames` for the filename queue;
            None (the default) lets the queue cycle indefinitely.

    Sets:
        self.shapes: dict mapping each name in self.FEATURES to the shape
            list read from that feature's '<name>_sz' entry.
        self.dtypes: dict mapping each name in self.FEATURES to the value
            REVERSE_TYPEMAP gives for the first element of '<name>_dtype'.
    """
    self._fnames = fnames
    self._fname_queue = tf.train.string_input_producer(
        self._fnames,
        capacity=1000,
        shuffle=shuffle,
        num_epochs=num_epochs,
        shared_name='input_files')
    self._reader = tf.TFRecordReader()

    # Probe a single record inside a separate, temporary graph so the probe
    # ops and their queue runners are not added to the caller's default graph.
    with tf.Graph().as_default():
        # shuffle=False makes the probe deterministic: it reads the first
        # record of the first file in `fnames`.
        fname_queue = tf.train.string_input_producer(
            self._fnames, shuffle=False)
        reader = tf.TFRecordReader()
        _, serialized = reader.read(fname_queue)
        shapes = self._parse_shape(serialized)
        dtypes = self._parse_dtype(serialized)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                # Fetch both in ONE run: two separate sess.run calls would
                # each execute reader.read and dequeue a different record.
                shape_vals, dtype_vals = sess.run([shapes, dtypes])
            finally:
                # Always stop and join the queue-runner threads, even when the
                # probe fails, so they do not outlive the session.
                coord.request_stop()
                coord.join(threads)

    self.shapes = {k: shape_vals[k + '_sz'].tolist() for k in self.FEATURES}
    self.dtypes = {k: REVERSE_TYPEMAP[dtype_vals[k + '_dtype'][0]]
                   for k in self.FEATURES}
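# (Scraped web-page residue, not code: "Comment list")
# (Scraped web-page residue, not code: "Table of contents")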