shuffle_tensor_list.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:master-thesis 作者: AndreasMadsen 项目源码 文件源码
def shuffle_tensor_list(input_tensors, **kwargs):
    dtypes = [tensor.dtype for tensor in input_tensors]

    shuffle_queue = tf.RandomShuffleQueue(dtypes=dtypes, **kwargs)
    shuffle_enqueue = shuffle_queue.enqueue(input_tensors)
    tf.train.add_queue_runner(
        tf.train.QueueRunner(shuffle_queue, [shuffle_enqueue])
    )

    output_tensors = shuffle_queue.dequeue()
    for output_tensor, input_tensor in zip(output_tensors, input_tensors):
        output_tensor.set_shape(input_tensor.get_shape())

    return tuple(output_tensors)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号