def __init__(self, img_dir, img_names, pre_process_img_func,
             extract_feat_func, batch_size, num_threads,
             multi_thread_stacking=False):
    """Record the image source and callbacks, then set up the prefetcher.

    Args:
      img_dir: kept as `self.img_dir`.
      img_names: kept as `self.img_names`; its length is passed to the
        prefetcher as the number of samples.
      pre_process_img_func: kept as `self.pre_process_img_func`.
      extract_feat_func: External model for extracting features. It takes
        a batch of images and returns a batch of features.
      batch_size: forwarded to `utils.Prefetcher`.
      num_threads: forwarded to `utils.Prefetcher` as `num_threads`.
      multi_thread_stacking: bool, whether to create a worker pool
        (`Pool(processes=8)`) intended to speed up `np.stack()`. When the
        system is memory overburdened, stacking a batch of images with
        `np.stack()` can take ridiculously long, e.g. several seconds for
        a batch of 64 images.
    """
    # Plain state first: the prefetcher is given `self.get_sample`, so these
    # attributes are in place before it is constructed.
    self.img_dir, self.img_names = img_dir, img_names
    self.pre_process_img_func = pre_process_img_func
    self.extract_feat_func = extract_feat_func

    num_samples = len(img_names)
    self.prefetcher = utils.Prefetcher(
        self.get_sample, num_samples, batch_size, num_threads=num_threads)

    self.epoch_done = True
    self.multi_thread_stacking = multi_thread_stacking

    # `self.pool` only exists when multi-thread stacking was requested.
    if not multi_thread_stacking:
        return
    self.pool = Pool(processes=8)
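# NOTE: stray web-page navigation text ("Comment list", "Table of contents")
# followed the code here; it is not part of the program.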