def __init__(self, dataset, feedin_shape, collate_fn=default_collate, threads=1, shuffle=False):
    """Set up the loader over ``dataset`` and prepare it for iteration.

    Args:
        dataset: Sample source. Only ``len(dataset)`` is used here;
            presumably samples are fetched by index in the ``get_batch_*``
            methods -- TODO confirm.
        feedin_shape: Mapping that must provide the keys ``'data'``,
            ``'label'`` and ``'batch_size'``. It is also passed to
            ``collate_fn``.
        collate_fn: Factory called once as ``collate_fn(feedin_shape)``; its
            return value is stored as ``self.collate_fn``.
        threads: Number of workers. When greater than 1, a thread pool
            (``multiprocessing.dummy.Pool``) is created and
            ``get_batch_multi_thread`` is selected; otherwise
            ``get_batch_single_thread`` is used.
        shuffle: Stored as ``self.shuffle`` (and under the legacy misspelled
            name ``self.shuflle``); presumably consulted by ``reset()`` --
            TODO confirm.

    Raises:
        KeyError: If ``feedin_shape`` lacks ``'data'``, ``'label'`` or
            ``'batch_size'``.
    """
    super(DataLoader, self).__init__()
    self.dataset = dataset
    self.threads = threads
    # collate_fn is a factory: it is instantiated once with the shape spec.
    self.collate_fn = collate_fn(feedin_shape)

    # Shape-related configuration taken from the feed-in spec.
    self.data_shapes = feedin_shape['data']
    self.label_shapes = feedin_shape['label']
    self.batch_size = feedin_shape['batch_size']

    # Iteration state.
    self.current = 0
    self.total = len(self.dataset)
    self.shuffle = shuffle
    # Legacy misspelled attribute, kept so existing code that reads it still works.
    self.shuflle = shuffle
    # Identity ordering of sample indices; presumably reordered by reset()
    # when shuffling -- TODO confirm.
    self.map_index = list(range(self.total))

    # Choose the batch-fetching strategy.
    self.get_batch = self.get_batch_single_thread
    if self.threads > 1:
        # multiprocessing.dummy exposes the Pool API backed by threads,
        # not processes, so workers share this process's memory.
        from multiprocessing.dummy import Pool as ThreadPool
        # NOTE(review): the pool is never closed here; confirm a matching
        # close()/join() exists elsewhere in the class.
        self.pool = ThreadPool(self.threads)
        self.get_batch = self.get_batch_multi_thread
    self.reset()
评论列表
文章目录