data_pipeline.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:hdrnet_legacy 作者: mgharbi 项目源码 文件源码
def __init__(self, fnames, shuffle=True, num_epochs=None):
    """Init from a list of filenames to enqueue.

    Builds the filename queue and record reader used by this pipeline, then
    reads one record in a throwaway graph/session to discover the shape and
    dtype of every feature.

    Args:
      fnames: list of .tfrecords filenames to enqueue.
      shuffle: if true, shuffle the list at each epoch
      num_epochs: forwarded to `tf.train.string_input_producer` for the main
        filename queue; None (the default) is passed through unchanged.

    Attributes set:
      shapes: dict mapping each name in `self.FEATURES` to a list obtained
        from the `<name>_sz` entry of the parsed record.
      dtypes: dict mapping each name in `self.FEATURES` to the
        `REVERSE_TYPEMAP` entry for the first element of the record's
        `<name>_dtype` entry.
    """
    self._fnames = fnames
    self._fname_queue = tf.train.string_input_producer(
        self._fnames,
        capacity=1000,
        shuffle=shuffle,
        num_epochs=num_epochs,
        shared_name='input_files')
    self._reader = tf.TFRecordReader()

    # Read first record to initialize the shape parameters. This happens in a
    # separate graph so the probing ops do not end up in the caller's graph.
    with tf.Graph().as_default():
      fname_queue = tf.train.string_input_producer(self._fnames)
      reader = tf.TFRecordReader()
      _, serialized = reader.read(fname_queue)
      shapes = self._parse_shape(serialized)
      dtypes = self._parse_dtype(serialized)

      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True

      with tf.Session(config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
          # Fetch both in a single run call: every run of `reader.read`
          # consumes a record, so separate calls would take the shapes and
          # the dtypes from two different records.
          raw_shapes, raw_dtypes = sess.run([shapes, dtypes])
        finally:
          # Always stop and join the queue-runner threads, even when the
          # read above fails.
          coord.request_stop()
          coord.join(threads)

    self.shapes = {k: raw_shapes[k+'_sz'].tolist() for k in self.FEATURES}
    self.dtypes = {
        k: REVERSE_TYPEMAP[raw_dtypes[k+'_dtype'][0]] for k in self.FEATURES}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号