sampling.py source code

python

Project: keras-text    Author: raghakot
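from fractions import Fraction
import logging

import numpy as np

logger = logging.getLogger(__name__)


def multi_label_train_test_split(y, test_size=0.2):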
    """Creates a test split with roughly the same multi-label distribution in `y`.

    Args:
        y: The multi-label outputs.
        test_size: The test size, strictly between 0 and 1.

    Returns:
        The train and test indices.
    """
    if test_size <= 0 or test_size >= 1:
        raise ValueError("`test_size` should be between 0 and 1")

    # Express test_size as the closest fraction with a small denominator (e.g. 0.2 -> 1/5).
    frac = Fraction(test_size).limit_denominator()
    test_folds, total_folds = frac.numerator, frac.denominator
    logger.warning('Inferring test_size as {}/{}. Generating {} folds. '
                   'The algorithm might fail if the denominator is large.'
                   .format(test_folds, total_folds, total_folds))

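    # `equal_distribution_folds` is not shown in this excerpt; it is expected to return
    # a list of `total_folds` index arrays with similar label distributions.
    folds = equal_distribution_folds(y, folds=total_folds)
    # The first `test_folds` folds form the test set; the remaining folds form the training set.
    test_indices = np.concatenate(folds[:test_folds])
    train_indices = np.concatenate(folds[test_folds:])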
    return train_indices, test_indices
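A minimal, self-contained sketch of the same idea follows. Because the keras-text version of `equal_distribution_folds` is not shown here, `greedy_label_folds` below is a hypothetical greedy stand-in (not the library's algorithm), and `infer_folds` and `split_indices` are likewise illustrative names. The sketch shows how `test_size` becomes a fold count (0.2 gives 1 of 5 folds, 0.3 gives 3 of 10) and how the first `test_folds` folds become the test set.

```python
from fractions import Fraction

import numpy as np


def infer_folds(test_size):
    """Map test_size to (test_folds, total_folds) via the closest simple fraction."""
    frac = Fraction(test_size).limit_denominator()
    return frac.numerator, frac.denominator


def greedy_label_folds(y, folds):
    """Hypothetical stand-in for equal_distribution_folds (not the keras-text code).

    Places each row of the binary label matrix `y` into the fold that holds the
    fewest positives for that row's labels so far, breaking ties by fold size
    and then by fold index.
    """
    y = np.asarray(y, dtype=int)
    label_counts = np.zeros((folds, y.shape[1]), dtype=int)  # positives per fold and label
    fold_sizes = np.zeros(folds, dtype=int)
    assignments = [[] for _ in range(folds)]
    for i, row in enumerate(y):
        labels = row.astype(bool)
        # Lower score means the fold currently holds fewer of this row's labels.
        scores = label_counts[:, labels].sum(axis=1)
        best = min(range(folds), key=lambda f: (scores[f], fold_sizes[f]))
        assignments[best].append(i)
        label_counts[best] += row
        fold_sizes[best] += 1
    return [np.array(idx, dtype=int) for idx in assignments]


def split_indices(y, test_size=0.2):
    """Same fold-to-split logic as multi_label_train_test_split, using the stand-in above."""
    if not 0 < test_size < 1:
        raise ValueError("`test_size` should be strictly between 0 and 1")
    test_folds, total_folds = infer_folds(test_size)
    folds = greedy_label_folds(y, total_folds)
    test_indices = np.concatenate(folds[:test_folds])
    train_indices = np.concatenate(folds[test_folds:])
    return train_indices, test_indices


if __name__ == "__main__":
    print(infer_folds(0.2))  # (1, 5)
    print(infer_folds(0.3))  # (3, 10)
    y = np.array([[1, 0]] * 5 + [[0, 1]] * 5)
    train, test = split_indices(y, test_size=0.2)
    print(test.tolist())  # [0, 5]
```

The fold count comes straight from the denominator, so a value such as 0.123 becomes 123/1000 and would require 1000 folds, which is why the original code warns about large denominators.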