def split_fold(in_pattern, rettrain=True, fold=0, cvs=5, include_vlaidation=True, split_seed=0):
    """
    Select one cross-validation fold from the files matching in_pattern.

    :param in_pattern: string of tfrecord patterns, passed to gfile.Glob
    :param rettrain: return the training part of the fold (True) or the
        left-out part (False)
    :param fold: zero-based index of the fold to use; must satisfy
        0 <= fold < cvs
    :param cvs: how many folds to split the files into
    :param include_vlaidation: if True, also append, for every selected file,
        the same name with each occurrence of 'train' replaced by 'validate'
        (the replacement applies to the whole path string and the resulting
        names are not checked for existence)
    :param split_seed: when > 0, the files are shuffled before splitting with
        this value as random_state; otherwise the split is not shuffled
    :return: list of selected file names
    :raises ValueError: if fold is outside the range [0, cvs)
    """
    # Explicit check instead of `assert`: asserts are stripped under -O, and an
    # out-of-range (or negative) fold used to fall through to the last fold.
    if not 0 <= fold < cvs:
        raise ValueError(
            "fold must satisfy 0 <= fold < cvs, got fold=%r, cvs=%r" % (fold, cvs)
        )

    files = gfile.Glob(in_pattern)

    if split_seed > 0:
        kf = KFold(n_splits=cvs, shuffle=True, random_state=split_seed)
    else:
        kf = KFold(n_splits=cvs)

    # KFold yields exactly `cvs` (train, test) index pairs, so indexing with the
    # validated fold is always in range.
    train_idx, test_idx = list(kf.split(files))[fold]
    chosen_idx = train_idx if rettrain else test_idx

    # Index the list directly so every returned element is a plain str.
    retfiles = [files[i] for i in chosen_idx]

    if include_vlaidation:
        retfiles += [fname.replace('train', 'validate') for fname in retfiles]
    return retfiles
评论列表
文章目录