def get_sample_dataset(dataset_properties):
    """Returns a sample dataset together with 2-fold cross-validation splits.

    Args:
        dataset_properties (dict): Dictionary corresponding to the properties
            of the dataset used to verify the estimator and metric generators.
            The ``'type'`` key selects the dataset and must be one of
            ``'multiclass'``, ``'iris'``, ``'mnist'``, ``'breast_cancer'``,
            ``'boston'`` or ``'diabetes'``. For ``'multiclass'`` every other
            key is forwarded as a keyword argument to
            ``datasets.make_classification``; for the remaining types the
            other keys are ignored. The dictionary itself is not modified.

    Returns:
        X (array-like): Features array
        y (array-like): Labels array
        splits (iterator): Single-pass iterator yielding the train/test index
            splits for cross-validation on ``X`` and ``y``. Stratified folds
            are used for the classification datasets and plain folds for
            ``'boston'`` and ``'diabetes'``.

    Raises:
        KeyError: If ``dataset_properties`` has no ``'type'`` key.
        exceptions.UserError: If the dataset type is unknown, or if
            ``datasets.make_classification`` rejects the given properties.
    """
    # Work on a copy so popping 'type' does not mutate the caller's dict.
    kwargs = dataset_properties.copy()
    data_type = kwargs.pop('type')

    if data_type == 'multiclass':
        # Only this call consumes user-supplied arguments, so only its
        # failures are reported back as a UserError.
        try:
            X, y = datasets.make_classification(random_state=8, **kwargs)
        except Exception as e:
            raise exceptions.UserError(repr(e))
        stratified = True
    elif data_type == 'iris':
        X, y = datasets.load_iris(return_X_y=True)
        stratified = True
    elif data_type == 'mnist':
        # The 'mnist' type is served by the bundled digits dataset.
        X, y = datasets.load_digits(return_X_y=True)
        stratified = True
    elif data_type == 'breast_cancer':
        X, y = datasets.load_breast_cancer(return_X_y=True)
        stratified = True
    elif data_type == 'boston':
        # NOTE(review): load_boston appears to have been removed in
        # scikit-learn 1.2 -- confirm the supported version range.
        X, y = datasets.load_boston(return_X_y=True)
        stratified = False
    elif data_type == 'diabetes':
        X, y = datasets.load_diabetes(return_X_y=True)
        stratified = False
    else:
        raise exceptions.UserError('Unknown dataset type {}'.format(data_type))

    # The splitters are left unshuffled, so the folds are deterministic and a
    # random_state is meaningless. It is deliberately not passed: recent
    # scikit-learn releases raise ValueError when random_state is set while
    # shuffle is False.
    if stratified:
        splits = model_selection.StratifiedKFold(n_splits=2).split(X, y)
    else:
        splits = model_selection.KFold(n_splits=2).split(X)

    return X, y, splits
评论列表
文章目录