inputs.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:tensorflow-qnd 作者: raviqqe 项目源码 文件源码
def _shuffle(inputs, capacity, min_after_dequeue, num_threads):
    if isinstance(inputs, dict):
        names, dtypes = zip(*[(key, input_.dtype)
                              for key, input_ in inputs.items()])
    else:
        dtypes = [input_.dtype for input_ in inputs]

    queue = tf.RandomShuffleQueue(
        capacity,
        min_after_dequeue,
        dtypes,
        **({'names': names} if isinstance(inputs, dict) else {}))

    tf.train.add_queue_runner(tf.train.QueueRunner(
        queue,
        [queue.enqueue(inputs)] * num_threads))

    shuffled_inputs = queue.dequeue()

    for key, input_ in (inputs.items()
                        if isinstance(inputs, dict) else
                        enumerate(inputs)):
        shuffled_inputs[key].set_shape(input_.get_shape())

    return shuffled_inputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号