hasy_tools.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:HASY 作者: MartinThoma 项目源码 文件源码
def _create_stratified_split(csv_filepath, n_splits):
    """
    Create a stratified split for the classification task.

    For every fold ``i`` (counting from 1) a directory
    ``classification-task/fold-<i>`` is created relative to the current
    working directory, and a ``train.csv`` and a ``test.csv`` are written
    into it. Each CSV has the header ``path, symbol_id, latex, user_id``;
    the ``path`` column is prefixed with ``../../``.

    If a fold directory already exists, a notice is printed and the CSV
    files inside it are overwritten.

    Parameters
    ----------
    csv_filepath : str
        Path to a CSV file which points to images
    n_splits : int
        Number of splits to make
    """
    import sys
    from sklearn.model_selection import StratifiedKFold

    data = _load_csv(csv_filepath)
    labels = [el['symbol_id'] for el in data]

    # The model_selection API takes the number of folds in the constructor
    # and the samples/labels in split(); the labels-in-constructor form
    # belonged to the removed sklearn.cross_validation module.
    skf = StratifiedKFold(n_splits=n_splits)

    # csv.writer needs a binary file on Python 2, but a text file opened
    # with newline='' on Python 3.
    if sys.version_info[0] >= 3:
        open_kwargs = {'mode': 'w', 'newline': ''}
    else:
        open_kwargs = {'mode': 'wb'}

    kdirectory = 'classification-task'
    if not os.path.exists(kdirectory):
        os.makedirs(kdirectory)

    # Only the labels determine a stratified split, so they also serve as
    # the placeholder for the sample argument X.
    folds = skf.split(labels, labels)
    for i, (train_index, test_index) in enumerate(folds, start=1):
        print("Create fold %i" % i)
        directory = "%s/fold-%i" % (kdirectory, i)
        if not os.path.exists(directory):
            os.makedirs(directory)
        else:
            print("Directory '%s' already exists. Please remove it." %
                  directory)
        train = [data[el] for el in train_index]
        test_ = [data[el] for el in test_index]
        for dataset, name in [(train, 'train'), (test_, 'test')]:
            with open("%s/%s.csv" % (directory, name),
                      **open_kwargs) as csv_file:
                csv_writer = csv.writer(csv_file)
                csv_writer.writerow(('path', 'symbol_id', 'latex', 'user_id'))
                for el in dataset:
                    # Paths are rewritten relative to the fold directory,
                    # which sits two levels below the working directory.
                    csv_writer.writerow(("../../%s" % el['path'],
                                         el['symbol_id'],
                                         el['latex'],
                                         el['user_id']))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号