olivetti_face.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:gcforest 作者: w821881341 项目源码 文件源码
def load_data(train_num, train_repeat):
    test_size = (10. - train_num) / 10
    data = fetch_olivetti_faces()
    X = data.images
    y = data.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=3, stratify=y)
    if train_repeat > 1:
        X_train = X_train.repeat(train_repeat, axis=0)
        y_train = y_train.repeat(train_repeat)
    return X_train, y_train, X_test, y_test
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号