def split_indices(files, labels, label_file, test_size=0.1, random_state=RANDOM_STATE):
    """Return the train/test pair from a single stratified shuffle split.

    Necessary for running with training on the melanoma database, which
    does not use ``per_patient`` labels.

    Args:
        files: Handed to ``get_names`` to obtain the names used for the
            label lookup.
        labels: Accepted but never read; the labels used for stratification
            are always re-derived via ``get_labels`` from ``label_file``.
        label_file: Forwarded to ``get_labels`` as ``label_file``.
        test_size: Forwarded to ``StratifiedShuffleSplit`` as ``test_size``.
        random_state: Forwarded to ``StratifiedShuffleSplit`` as
            ``random_state``.

    Returns:
        A ``(train, test)`` tuple taken from the first (and only) split the
        splitter yields.
    """
    # The incoming ``labels`` argument is intentionally not consulted here:
    # labels are looked up afresh with per_patient disabled.
    sample_names = get_names(files)
    stratify_labels = get_labels(
        sample_names,
        label_file=label_file,
        per_patient=False,
    )

    # NOTE(review): passing the labels to the constructor together with
    # ``n_iter`` looks like the legacy scikit-learn ``cross_validation`` API
    # -- confirm which library version this file targets.
    splitter = cross_validation.StratifiedShuffleSplit(
        stratify_labels,
        test_size=test_size,
        random_state=random_state,
        n_iter=1,
    )

    # Only one iteration was requested, so take the first yielded pair.
    train_idx, test_idx = next(iter(splitter))
    return train_idx, test_idx
评论列表
文章目录