def split(flags):
    """Partition the files in ``flags.input_path`` into ``flags.folds`` folds.

    The result is cached at ``flags.split_path`` as a NumPy-pickled dict; when
    that file already exists it is loaded and returned instead of recomputing.

    Args:
        flags: object exposing ``split_path`` (cache file location),
            ``input_path`` (directory whose entries are split) and ``folds``
            (number of folds, must be >= 1).

    Returns:
        dict mapping fold index ``0 .. folds-1`` to a list of
        ``"<input_path>/<entry>"`` strings. Folds are contiguous slices of the
        shuffled listing, each of size ``ceil(n / folds)`` except that trailing
        folds may be shorter or empty.

    Raises:
        ValueError: if ``flags.folds`` is smaller than 1.
    """
    if os.path.exists(flags.split_path):
        # The cache holds a pickled dict, which NumPy >= 1.16.3 refuses to load
        # unless allow_pickle is enabled explicitly.
        # NOTE(security): unpickling executes arbitrary code; the cache file
        # must come from a trusted location.
        return np.load(flags.split_path, allow_pickle=True).item()

    folds = flags.folds
    if folds < 1:
        raise ValueError("folds must be >= 1, got %r" % (folds,))
    path = flags.input_path

    # os.listdir returns entries in arbitrary order; sort first so the seeded
    # shuffle yields the same split on every machine and filesystem.
    img_list = ["%s/%s" % (path, img) for img in sorted(os.listdir(path))]
    # Seeding the module-level generator is kept as-is because code running
    # after this call may rely on that side effect.
    random.seed(6)
    random.shuffle(img_list)

    n = len(img_list)
    num = (n + folds - 1) // folds  # ceil(n / folds): size of a full fold
    # Slicing clamps at the end of the list, so no explicit min() is needed.
    dic = {i: img_list[i * num:(i + 1) * num] for i in range(folds)}

    # Write through a file handle: np.save(path, ...) would append ".npy" to a
    # path lacking that suffix, and the existence check above would then never
    # find the cache.
    with open(flags.split_path, "wb") as fh:
        np.save(fh, dic)
    return dic
评论列表
文章目录