def _create_stratified_split(csv_filepath, n_splits):
    """
    Create a stratified split for the classification task.

    The rows loaded from ``csv_filepath`` are divided into ``n_splits``
    folds, stratified on each row's ``symbol_id``. For fold ``i``
    (numbered from 1) the files ``classification-task/fold-<i>/train.csv``
    and ``classification-task/fold-<i>/test.csv`` are written. Each file
    has the header ``path, symbol_id, latex, user_id``.

    Parameters
    ----------
    csv_filepath : str
        Path to a CSV file which points to images. Each loaded row must
        provide the keys 'path', 'symbol_id', 'latex' and 'user_id'.
    n_splits : int
        Number of splits to make
    """
    from sklearn.model_selection import StratifiedKFold
    data = _load_csv(csv_filepath)
    labels = [el['symbol_id'] for el in data]
    # In the sklearn.model_selection API the constructor only takes the
    # number of splits; the data goes to ``split``. Stratification depends
    # only on ``y``, so the label list also serves as the placeholder ``X``.
    skf = StratifiedKFold(n_splits=n_splits)
    kdirectory = 'classification-task'
    if not os.path.exists(kdirectory):
        os.makedirs(kdirectory)
    folds = skf.split(labels, labels)
    for i, (train_index, test_index) in enumerate(folds, start=1):
        print("Create fold %i" % i)
        directory = "%s/fold-%i" % (kdirectory, i)
        if not os.path.exists(directory):
            os.makedirs(directory)
        else:
            # Only a warning: the fold's CSV files below are overwritten.
            print("Directory '%s' already exists. Please remove it." %
                  directory)
        train = [data[el] for el in train_index]
        test_ = [data[el] for el in test_index]
        for dataset, name in [(train, 'train'), (test_, 'test')]:
            csv_path = "%s/%s.csv" % (directory, name)
            # On Python 3, csv.writer needs a text-mode file opened with
            # newline=''. Binary mode ('wb') raises TypeError on write.
            with open(csv_path, 'w', newline='') as csv_file:
                csv_writer = csv.writer(csv_file)
                csv_writer.writerow(('path', 'symbol_id', 'latex', 'user_id'))
                for el in dataset:
                    # Presumably the input paths are relative to the current
                    # working directory. The fold CSV sits two levels deeper,
                    # so the path is prefixed with "../../" (TODO confirm).
                    csv_writer.writerow(("../../%s" % el['path'],
                                         el['symbol_id'],
                                         el['latex'],
                                         el['user_id']))
评论列表
文章目录