splitutils.py source code

python

Project: Y8M  Author: mpekalski
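import numpy as np
from sklearn.model_selection import KFold
# Assumes TensorFlow 1.x, where gfile.Glob is available (tf.io.gfile.glob in 2.x).
from tensorflow import gfile


def split_fold(in_pattern, rettrain=True, fold=0, cvs=5, include_vlaidation=True, split_seed=0):
    """
    Splits the files matching in_pattern into training and held-out sets
    and returns one of the two for the requested fold.
    :param in_pattern: glob pattern matching the tfrecord files
    :param rettrain: return the training set (True) or the left-out set (False)
    :param fold: zero-based index of the fold to return, must be < cvs
    :param cvs: number of cross-validation folds
    :param include_vlaidation: also return the matching validation files
        (the parameter name is misspelled in the original and kept as-is)
    :param split_seed: if > 0, shuffle the files with this seed before splitting;
        otherwise split them in the order returned by the glob
    :return: list of tfrecord file names for the selected subset
    """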
    assert fold < cvs

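    files = gfile.Glob(in_pattern)
    if split_seed > 0:
        # Shuffled folds, reproducible for a given seed.
        kf = KFold(n_splits=cvs, shuffle=True, random_state=split_seed)
    else:
        # Unshuffled folds: contiguous blocks in the order Glob returned the files.
        kf = KFold(n_splits=cvs)

    # Advance through the splits until the requested fold; after the break,
    # train and test hold the index arrays for that fold.
    for i, (train, test) in enumerate(kf.split(files)):
        if i == fold:
            break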

    if rettrain:
        retfiles = list(np.array(files)[train])
    else:
        retfiles = list(np.array(files)[test])

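    if include_vlaidation:
        # Assumes each validation file shares its name with a training file,
        # with 'train' swapped for 'validate'. Note that str.replace rewrites
        # every occurrence of 'train' in the path, directory names included.
        addition = [fname.replace('train', 'validate') for fname in retfiles]
        retfiles += addition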

    return retfiles
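
To see what the function returns for a given fold without needing TensorFlow or real tfrecord files, the following dependency-free sketch imitates the unshuffled case (split_seed of 0). The helpers contiguous_folds and pick_fold and the file names are hypothetical and are not part of the Y8M project; the fold sizes follow the contiguous layout that scikit-learn's KFold uses when shuffle is off, where the first n % k folds receive one extra item.

```python
def contiguous_folds(n_items, n_folds):
    """Yield (train_idx, test_idx) pairs using contiguous, unshuffled folds."""
    base, extra = divmod(n_items, n_folds)
    start = 0
    for k in range(n_folds):
        # The first `extra` folds get one more item than the rest.
        size = base + (1 if k < extra else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_items))
        yield train, test
        start += size


def pick_fold(files, fold, n_folds, rettrain=True, include_validation=True):
    """Hypothetical stand-in for split_fold that works on an in-memory list."""
    assert 0 <= fold < n_folds
    train, test = list(contiguous_folds(len(files), n_folds))[fold]
    chosen = [files[i] for i in (train if rettrain else test)]
    if include_validation:
        # Same naming assumption as split_fold: 'train' -> 'validate'.
        chosen += [name.replace('train', 'validate') for name in chosen]
    return chosen


files = ['train%d.tfrecord' % i for i in range(5)]  # made-up file names
held_out = pick_fold(files, fold=1, n_folds=5, rettrain=False)
print(held_out)  # ['train1.tfrecord', 'validate1.tfrecord']
```

With five files and five folds, each held-out set contains one training file plus its validation counterpart, and the corresponding training set contains the other four pairs.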