def hookTrainData(self, sampleIdxs):
    """Assemble one training batch of (source, target, label) arrays.

    For every class index in ``sampleIdxs`` one image pair is drawn uniformly
    at random from ``self.trainDict[idx]``.  Both images of the pair are read
    from ``self.img_path`` with ``cv2.IMREAD_COLOR`` and, when
    ``self.is_crop`` is true, resized to ``self.crop_size``.

    Args:
        sampleIdxs: non-empty sequence of class indices.  Each index is used
            as the key into ``self.trainDict`` / ``self.trainLenClass`` and
            also as the label of the sample drawn for it.

    Returns:
        Tuple ``(sources, targets, labels)``.  ``sources`` and ``targets``
        stack the per-sample images along a new leading batch axis (all images
        must therefore share one shape -- guaranteed only when ``is_crop`` is
        set); ``labels`` is a 1-D array of the indices, in input order.

    Raises:
        ValueError: if ``sampleIdxs`` is empty.
        IOError: if an image file cannot be read.
    """
    if len(sampleIdxs) == 0:
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        raise ValueError('we need a non-empty batch list')

    # cv2.resize takes dsize as (width, height), so crop_size is presumably
    # stored as (height, width).  Only touched when cropping is enabled.
    dsize = (self.crop_size[1], self.crop_size[0]) if self.is_crop else None

    def _read(name):
        # Load one image relative to img_path, optionally resizing it.
        path = os.path.join(self.img_path, name)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise IOError('could not read image: %s' % path)
        if dsize is not None:
            img = cv2.resize(img, dsize)
        return img

    source_list, target_list, label_list = [], [], []
    for idx in sampleIdxs:
        # Each entry of trainDict[idx] is indexed with [0]/[1] below, i.e. it
        # appears to be a (prev_img, next_img) pair of file names.
        pairs = self.trainDict[idx]
        # Draw a scalar index: the 1-element array from np.random.choice(n, 1)
        # cannot index a Python list on current NumPy.
        img_pair = pairs[np.random.randint(0, self.trainLenClass[idx])]
        source_list.append(_read(img_pair[0]))
        target_list.append(_read(img_pair[1]))
        label_list.append(idx)

    return np.stack(source_list), np.stack(target_list), np.asarray(label_list)
评论列表
文章目录