DataLoader.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:mxbox 作者: Lyken17 项目源码 文件源码
def __init__(self, dataset, feedin_shape, collate_fn=default_collate, threads=1, shuffle=False):
        """Set up a batch loader over ``dataset``.

        Args:
            dataset: Source of samples; ``len(dataset)`` must be defined
                because it sizes the index map.
            feedin_shape: Mapping with at least the keys ``'data'``,
                ``'label'`` and ``'batch_size'``. It is also passed to
                ``collate_fn`` to build the per-batch collate callable.
            collate_fn: Factory called once as ``collate_fn(feedin_shape)``;
                its return value is what gets stored as ``self.collate_fn``.
            threads: Number of worker threads. Values greater than 1 create a
                thread pool and route ``get_batch`` to the multi-threaded path.
            shuffle: Shuffle flag, stored as ``self.shuffle`` (and also under
                the legacy misspelled name ``self.shuflle``).
        """
        super(DataLoader, self).__init__()

        self.dataset = dataset
        self.threads = threads
        # collate_fn is a factory: it is called with the shape spec and the
        # callable it returns is what the loader keeps.
        self.collate_fn = collate_fn(feedin_shape)

        # shape related variables
        self.data_shapes = feedin_shape['data']
        self.label_shapes = feedin_shape['label']
        self.batch_size = feedin_shape['batch_size']

        # loader related variables
        self.current = 0
        self.total = len(self.dataset)
        self.shuffle = shuffle
        # Legacy misspelled attribute kept as an alias so existing readers
        # keep working. NOTE(review): other methods (e.g. reset) may still
        # read ``self.shuflle`` -- confirm and migrate them to ``self.shuffle``.
        self.shuflle = shuffle
        self.map_index = list(range(self.total))

        # prepare for loading
        self.get_batch = self.get_batch_single_thread
        if self.threads > 1:
            # multiprocessing.dummy provides the Pool API backed by threads,
            # not processes, so workers share this object's memory.
            from multiprocessing.dummy import Pool as ThreadPool
            self.pool = ThreadPool(self.threads)
            self.get_batch = self.get_batch_multi_thread

        self.reset()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号