def split_dataset(dataset, split_ratio, mode):
    """Randomly partition ``dataset`` into a train set and a test set.

    Args:
        dataset: Indexable sequence of class entries. In ``'SPLIT_IMAGES'``
            mode each entry must expose ``name`` and ``image_paths``.
        split_ratio: Fraction of the items (classes or images, depending on
            ``mode``) assigned to the train set; the remainder goes to the
            test set.
        mode: ``'SPLIT_CLASSES'`` to assign whole classes to one side or the
            other, or ``'SPLIT_IMAGES'`` to split the images of every class
            between both sides.

    Returns:
        A ``(train_set, test_set)`` tuple of lists. In ``'SPLIT_CLASSES'``
        mode the lists hold the original entries; in ``'SPLIT_IMAGES'`` mode
        they hold new ``ImageClass`` objects.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """
    if mode == 'SPLIT_CLASSES':
        nrof_classes = len(dataset)
        class_indices = np.arange(nrof_classes)
        np.random.shuffle(class_indices)
        split = int(round(nrof_classes * split_ratio))
        train_set = [dataset[i] for i in class_indices[:split]]
        # Slice to the end: the previous ``[split:-1]`` silently dropped the
        # last shuffled class from both sets.
        test_set = [dataset[i] for i in class_indices[split:]]
    elif mode == 'SPLIT_IMAGES':
        train_set = []
        test_set = []
        min_nrof_images = 2
        for cls in dataset:
            paths = cls.image_paths
            # NOTE: shuffles the class's own path list in place, so the
            # caller's dataset is reordered as a side effect.
            np.random.shuffle(paths)
            split = int(round(len(paths) * split_ratio))
            if split < min_nrof_images:
                # Too few images would land on the train side; skip class.
                continue
            train_set.append(ImageClass(cls.name, paths[:split]))
            # Slice to the end: the previous ``[split:-1]`` discarded one
            # image of every class.
            test_set.append(ImageClass(cls.name, paths[split:]))
    else:
        raise ValueError('Invalid train/test split mode "%s"' % mode)
    return train_set, test_set
评论列表
文章目录