utils.py source code

python

Project: seqhawkes    Author: mlukasik
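import numpy as np


def foldsplitter(taskcolumn, train_set_sizes):
    '''
    For each task id in taskcolumn and each size in train_set_sizes, yield a
    (train_mask, test_mask) pair of boolean arrays. The first train_set_size
    rows of that task id go to training, together with all rows of the other
    task ids; the remaining rows of that task id (positions train_set_size
    onwards, counting from 0) are used for testing.
    '''
    taskcolumn = np.asarray(taskcolumn)
    folds = sorted(set(taskcolumn))
    for fold in folds:
        for train_set_size in train_set_sizes:
            # Boolean mask of the rows that belong to the current task id.
            testfold2train = taskcolumn == fold
            # Keep only the first train_set_size rows of this task for training.
            cnt = 0
            for i in range(len(testfold2train)):
                if testfold2train[i]:
                    cnt += 1
                    if cnt > train_set_size:
                        testfold2train[i] = False
            # Rows of all other task ids are always part of the training set.
            remaining_train = taskcolumn != fold
            train = np.logical_or(testfold2train, remaining_train)

            yield (train, np.logical_not(train))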
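The per-row counting loop above can also be expressed with a running count. The sketch below is a hypothetical vectorized variant (the name `foldsplitter_vectorized` is not part of seqhawkes); it assumes `taskcolumn` is a 1-D array-like of task ids and is intended to produce the same masks as `foldsplitter`, using `np.cumsum` to get each row's 1-based position within its task.

```python
import numpy as np


def foldsplitter_vectorized(taskcolumn, train_set_sizes):
    """Hypothetical vectorized variant of foldsplitter (not from seqhawkes).

    For every task id and every train_set_size, yields (train_mask, test_mask):
    the first train_set_size rows of that task id plus all rows of the other
    task ids are for training; the remaining rows of that task id are for testing.
    """
    taskcolumn = np.asarray(taskcolumn)
    for fold in sorted(set(taskcolumn)):
        in_fold = taskcolumn == fold
        # 1-based position of each row within the current task
        # (only meaningful where in_fold is True).
        position = np.cumsum(in_fold)
        for train_set_size in train_set_sizes:
            # Other tasks always train; this task trains only on its first rows.
            train = ~in_fold | (position <= train_set_size)
            yield train, ~train
```

For example, with task ids `[0, 0, 1, 0, 1, 1]` and `train_set_sizes=[1, 2]` it yields four splits (two task ids times two sizes); for task 0 with size 1, the test mask selects rows 1 and 3. Computing the counts once per task id, outside the inner loop over sizes, avoids repeating the Python-level scan for every size.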