splitting_utils.py source code

python
Project: edm2016  Author: Knewton
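import numpy as np

# USER_IDX_KEY and _get_fold_student_idx are defined elsewhere in the project
# and are not shown in this excerpt.


def split_data(data, num_folds, seed=0):
    """ Split all interactions into K-fold sets of training and test dataframes.  Splitting is done
    by assigning student ids to the training or test sets, so every interaction of a given student
    lands on the same side of a fold.

    :param pd.DataFrame data: all interactions
    :param int num_folds: number of folds
    :param int seed: seed for the splitting
    :return: a generator over (train dataframe, test dataframe) tuples
    :rtype: generator[(pd.DataFrame, pd.DataFrame)]
    """
    # break up students into folds
    fold_student_idx = _get_fold_student_idx(np.unique(data[USER_IDX_KEY]), num_folds=num_folds,
                                             seed=seed)

    for fold_test_student_idx in fold_student_idx:
        # boolean mask over rows: True where the row's student belongs to this fold's test set
        # (np.isin replaces np.in1d, which is deprecated as of NumPy 2.0)
        test_idx = np.isin(data[USER_IDX_KEY], fold_test_student_idx)
        train_idx = np.logical_not(test_idx)
        # copy so that callers can modify a split without touching the original dataframe
        yield (data[train_idx].copy(), data[test_idx].copy())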
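
The excerpt above depends on two names it does not define, so it cannot be run on its own. The following is a minimal, self-contained sketch of the same student-level K-fold idea. It rests on assumptions: the column name `"user_idx"` used for `USER_IDX_KEY` is a guess, and this `_get_fold_student_idx` is a hypothetical stand-in that shuffles the student ids and cuts them into nearly equal groups. The project's real helper may behave differently.

```python
import numpy as np
import pandas as pd

# Assumed column name; the project's actual constant is not shown in the excerpt.
USER_IDX_KEY = "user_idx"


def _get_fold_student_idx(student_ids, num_folds, seed=0):
    """Hypothetical stand-in: shuffle the ids, then cut them into num_folds groups."""
    rng = np.random.RandomState(seed)
    shuffled = rng.permutation(student_ids)
    return np.array_split(shuffled, num_folds)


def split_data(data, num_folds, seed=0):
    """Yield (train, test) dataframes, splitting by student id rather than by row."""
    fold_student_idx = _get_fold_student_idx(
        np.unique(data[USER_IDX_KEY]), num_folds=num_folds, seed=seed)
    for fold_test_student_idx in fold_student_idx:
        # True for every row whose student is in this fold's test group
        test_idx = np.isin(data[USER_IDX_KEY], fold_test_student_idx)
        train_idx = np.logical_not(test_idx)
        yield data[train_idx].copy(), data[test_idx].copy()


# Toy interaction log: six students, ten interactions in total.
data = pd.DataFrame({
    USER_IDX_KEY: [0, 0, 1, 1, 2, 2, 3, 3, 4, 5],
    "correct": [1, 0, 1, 1, 0, 0, 1, 0, 1, 1],
})

folds = list(split_data(data, num_folds=3, seed=42))
for fold_num, (train, test) in enumerate(folds):
    print(fold_num, sorted(test[USER_IDX_KEY].unique()), len(train), len(test))
```

Because the folds are built from student ids, no student appears in both the training and the test dataframe of the same fold, and each student is held out exactly once across all folds.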