def load_data(train_num, train_repeat=1, *, random_state=3):
    """Split the Olivetti faces dataset into per-subject train/test sets.

    The split is stratified on the subject label (``data.target``), so each
    subject contributes ``train_num`` images to the training set and its
    remaining images to the test set.

    Args:
        train_num: Number of training images per subject. Olivetti has 10
            images per subject, so this must satisfy ``0 < train_num < 10``.
        train_repeat: When greater than 1, every training sample and its
            label are repeated this many times, consecutively, along the
            first axis (``ndarray.repeat``). Values <= 1 leave the training
            set unchanged.
        random_state: Seed forwarded to ``train_test_split``. The default of
            3 reproduces the original fixed split.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``. Note this order differs
        from ``train_test_split``'s ``(X_train, X_test, y_train, y_test)``.
        The X arrays are taken from ``data.images``.

    Raises:
        ValueError: If ``train_num`` is not strictly between 0 and 10.
    """
    # Every Olivetti subject has exactly 10 images; the per-subject train
    # count is converted into the test fraction expected by train_test_split.
    images_per_subject = 10
    # Validate before fetching: fetch_olivetti_faces may download the dataset,
    # and an out-of-range value would otherwise surface only as an opaque
    # test_size error from train_test_split.
    if not 0 < train_num < images_per_subject:
        raise ValueError(
            f"train_num must be in (0, {images_per_subject}), got {train_num!r}"
        )
    test_size = (images_per_subject - train_num) / images_per_subject

    data = fetch_olivetti_faces()
    X = data.images
    y = data.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )

    if train_repeat > 1:
        # Repeat samples and labels identically so they stay aligned.
        X_train = X_train.repeat(train_repeat, axis=0)
        y_train = y_train.repeat(train_repeat)
    return X_train, y_train, X_test, y_test
评论列表
文章目录