def hookTrainData(self, sampleIdxs):
    """Load a batch of training image pairs and their ground-truth flow.

    For every index in ``sampleIdxs`` this reads the two frames named by
    ``self.trainList[idx][0]`` / ``[1]`` from ``self.img_path`` and the
    ``.flo`` file named after the first frame from
    ``<self.data_path>/training/flow``. It then stacks each kind of array
    along a new leading batch axis.

    Args:
        sampleIdxs: Non-empty sequence of indices into ``self.trainList``.

    Returns:
        Tuple ``(source, target, flow)`` of NumPy arrays. Each has one entry
        per index in ``sampleIdxs`` along axis 0, in the same order.

    Raises:
        AssertionError: if ``sampleIdxs`` is empty.
        OSError: if OpenCV cannot read an image, or if ``utils.readFlow``
            returns ``None`` for a flow file.
        ValueError: from ``np.stack`` if the loaded arrays of one kind do
            not all share the same shape.
    """
    assert len(sampleIdxs) > 0, 'we need a non-empty batch list'

    # cv2.resize expects dsize as (width, height); crop_size is presumably
    # (height, width), hence the swap -- TODO confirm against where it is set.
    dsize = (self.crop_size[1], self.crop_size[0]) if self.is_crop else None
    flow_dir = os.path.join(self.data_path, 'training', "flow")

    def _read_image(name):
        # Read one colour frame (OpenCV yields BGR channel order), resized if requested.
        path = os.path.join(self.img_path, name)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread does not raise on a missing/undecodable file; it returns None.
        if img is None:
            raise OSError("could not read image: %s" % path)
        if dsize is not None:
            img = cv2.resize(img, dsize)
        return img

    source_list, target_list, flow_gt = [], [], []
    for idx in sampleIdxs:
        img_pair = self.trainList[idx]
        prev_img, next_img = img_pair[0], img_pair[1]

        source_list.append(_read_image(prev_img))
        target_list.append(_read_image(next_img))

        # The flow file shares the first frame's name with its extension swapped
        # for ".flo"; splitext handles extensions of any length (e.g. ".jpeg").
        flow_path = os.path.join(flow_dir, os.path.splitext(prev_img)[0] + ".flo")
        flow = utils.readFlow(flow_path)
        # NOTE(review): some readFlow implementations return None on a bad
        # magic number instead of raising -- guard so it cannot slip into the batch.
        if flow is None:
            raise OSError("could not read flow file: %s" % flow_path)
        # NOTE(review): when is_crop is set only the images are resized; the flow
        # keeps its original resolution and vector magnitudes. Confirm the caller
        # rescales it, otherwise image and flow batches will not align.
        flow_gt.append(flow)

    # np.stack(xs) == np.concatenate([np.expand_dims(x, 0) for x in xs], axis=0)
    return np.stack(source_list), np.stack(target_list), np.stack(flow_gt)
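# Web-page residue, apparently copied along with this snippet; not code:
# "评论列表" = "Comment list", "文章目录" = "Table of contents".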