def hookTrainData(self, sampleIdxs):
    """Assemble one training batch of (source, target, label) samples.

    For every entry ``idx`` of ``sampleIdxs`` one image pair is drawn at
    random from ``self.trainDict[idx]``. The pair's first file becomes the
    source image, its second file the target image, and ``idx`` itself is
    used as the label.

    Args:
        sampleIdxs: Non-empty sized sequence (list or numpy array) of keys
            into ``self.trainDict`` and ``self.trainLenClass``. Each entry
            yields one sample; duplicates are allowed.

    Returns:
        Tuple ``(sources, targets, labels)``. ``sources`` and ``targets`` are
        the images stacked along a new leading batch axis. Each image is read
        in colour by OpenCV (BGR channel order), joined onto ``self.img_path``,
        and resized to ``self.image_size``. ``image_size`` is interpreted as
        ``(height, width)``, because ``cv2.resize`` takes its size as
        ``(width, height)``. ``labels`` is a 1-D array of the entries of
        ``sampleIdxs``.

    Raises:
        ValueError: If ``sampleIdxs`` is empty.
        IOError: If an image file cannot be read or decoded.
    """
    # Use len() rather than truthiness so numpy index arrays are accepted.
    # Raise rather than assert, because asserts are stripped under `python -O`.
    if len(sampleIdxs) == 0:
        raise ValueError('we need a non-empty batch list')

    # cv2.resize expects (width, height); image_size is (height, width).
    dsize = (self.image_size[1], self.image_size[0])

    def load(file_name):
        """Read one colour image relative to ``self.img_path`` and resize it."""
        path = os.path.join(self.img_path, file_name)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('could not read image: %s' % path)
        return cv2.resize(img, dsize)

    sources, targets, labels = [], [], []
    for idx in sampleIdxs:
        # Draw a *scalar* index. With size=1, np.random.choice returns a
        # length-1 array, which a Python list rejects as an index and which
        # turns an (N, 2) array row into shape (1, 2), breaking img_pair[1].
        # NOTE(review): assumes trainLenClass[idx] == len(trainDict[idx]).
        pair_idx = np.random.choice(self.trainLenClass[idx])
        img_pair = self.trainDict[idx][pair_idx]
        sources.append(load(img_pair[0]))
        targets.append(load(img_pair[1]))
        labels.append(idx)

    # NOTE: if images are ever read with cv2.IMREAD_GRAYSCALE, the stacked
    # arrays lose their channel axis; restore it with np.expand_dims(..., axis=3).
    return np.stack(sources, axis=0), np.stack(targets, axis=0), np.array(labels)
评论列表
文章目录