input.py source code

python

Project: LiTeFlow  Author: petrux

import tensorflow as tf  # TensorFlow 1.x graph-mode API (queues, QueueRunner)


def shuffle(tensors,
            capacity=32,
            min_after_dequeue=16,
            num_threads=1,
            dtypes=None,
            shapes=None,
            seed=None,
            shared_name=None,
            name='shuffle'):
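    """Wrapper around the creation of a `tf.RandomShuffleQueue`.

    Return a dequeue op that dequeues elements from `tensors` in a
    random order, through a `tf.RandomShuffleQueue` (see its documentation
    for further details). The queue is fed by `num_threads` enqueue ops
    registered as a `tf.train.QueueRunner`.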

    Arguments:
      tensors: an iterable of tensors.
      capacity: (Optional) the capacity of the queue; default value set to 32.
      num_threads: (Optional) the number of threads to be used for the queue runner;
        default value set to 1.
      min_after_dequeue: (Optional) minimum number of elements to remain in the
        queue after a `dequeue` or `dequeue_many` has been performed,
        in order to ensure better mixing of elements; default value set to 16.
      dtypes: (Optional) list of `DType` objects, one for each tensor in `tensors`;
        if not provided, will be inferred from `tensors`.
      shapes: (Optional) list of shapes, one for each tensor in `tensors`.
      seed: (Optional) seed for random shuffling.
      shared_name: (Optional) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: (Optional) name scope for the ops; default value set to 'shuffle'.

    Returns:
      The tensors randomly dequeued from the queue, one for each tensor in
      `tensors`.
    """

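    tensors = list(tensors)
    # Pass the inputs as `values` so the scope is opened in their graph
    # (the second positional argument of `tf.name_scope` is `default_name`).
    with tf.name_scope(name, values=tensors):
        # Infer the component dtypes from the input tensors when not given.
        dtypes = dtypes or [t.dtype for t in tensors]
        queue = tf.RandomShuffleQueue(
            seed=seed,
            shared_name=shared_name,
            name='random_shuffle_queue',
            dtypes=dtypes,
            shapes=shapes,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue)
        # One copy of the enqueue op per thread; the QueueRunner runs each
        # in its own thread once the queue runners are started.
        enqueue = queue.enqueue(tensors)
        runner = tf.train.QueueRunner(queue, [enqueue] * num_threads)
        tf.train.add_queue_runner(runner)
        # Each evaluation of this op removes one random element from the queue.
        dequeue = queue.dequeue()
        return dequeue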
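The role of `min_after_dequeue` can be illustrated without TensorFlow. The following pure-Python sketch uses a hypothetical helper, `shuffle_buffer`, which is not part of LiTeFlow or TensorFlow. It assumes a single consumer that dequeues as soon as the queue allows it: an element is removed at random only while more than `min_after_dequeue` elements are buffered, and whatever remains is drained in random order once the input is exhausted, as happens when the queue is closed. Because this eager consumer never lets the buffer grow past `min_after_dequeue + 1` elements, `capacity` plays no role here and is omitted.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def shuffle_buffer(items: Iterable[T],
                   min_after_dequeue: int = 16,
                   seed: Optional[int] = None) -> Iterator[T]:
    """Yield `items` in a locally shuffled order (hypothetical helper).

    Single-threaded approximation of a RandomShuffleQueue with an eager
    consumer: a dequeue happens only while more than `min_after_dequeue`
    elements are queued, so at least that many remain afterwards.
    """
    if min_after_dequeue < 0:
        raise ValueError("min_after_dequeue must be non-negative")
    rng = random.Random(seed)
    buffer: List[T] = []

    def pop_random() -> T:
        # Swap a uniformly chosen element to the end, then pop it in O(1).
        i = rng.randrange(len(buffer))
        buffer[i], buffer[-1] = buffer[-1], buffer[i]
        return buffer.pop()

    for item in items:
        buffer.append(item)  # "enqueue"
        # A "dequeue" is allowed only while more than min_after_dequeue
        # elements are queued.
        if len(buffer) > min_after_dequeue:
            yield pop_random()

    # Input exhausted (the queue is "closed"): drain the rest in random order.
    while buffer:
        yield pop_random()
```

Larger `min_after_dequeue` values give better mixing at the cost of memory and a longer warm-up before the first element is produced; in this sketch, `min_after_dequeue=0` simply preserves the input order. In TensorFlow 1.x graph mode, the op returned by `shuffle` only produces values after the registered queue runners have been started, for example with `tf.train.start_queue_runners`.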