def fetch_from_config(cfg):
    """Load the data set named in ``cfg`` and return a train/test split.

    ``cfg`` is indexed by section and key, and its sections expose
    ``getboolean``/``getfloat`` (presumably a ``configparser.ConfigParser``
    -- confirm against callers). Keys read:

    * ``fetch.name`` -- data set name: ``'OLIVETTI'``, ``'MNIST original'``
      or another mldata name (sklearn path); ``'LETTERS'``, ``'ISOLET'`` or
      ``'SHREC14'`` (local path).
    * ``fetch.sklearn`` -- fetch through ``sklearn.datasets`` when true,
      otherwise use the project's own loaders.
    * ``fetch.real`` / ``fetch.desc`` -- forwarded to ``load_shrec14``.
    * ``pre_process.normalize`` -- divide MNIST by 255 (sklearn path only).
    * ``pre_process.norm`` -- norm passed to ``prep.normalize`` for SHREC14.
    * ``train_test.test_size`` -- ``test_size`` for ``train_test_split``
      (not read for MNIST or for locally loaded ISOLET).

    Returns:
        ``(x_tr, x_te, y_tr, y_te)``.

    Raises:
        NameError: if ``fetch.sklearn`` is false and ``fetch.name`` is not a
            known local data set. (``ValueError`` would be more conventional;
            ``NameError`` is kept so existing callers still catch it.)
    """
    data_set_name = cfg['fetch']['name']
    if cfg['fetch'].getboolean('sklearn'):
        X, y = _fetch_sklearn_data_set(cfg, data_set_name)
    elif data_set_name == 'ISOLET':
        # The local ISOLET loader already returns a train/test partition,
        # so it is returned as-is without further splitting.
        x_tr, x_te, y_tr, y_te = fetch_load_isolet()
        return x_tr, x_te, y_tr, y_te
    else:
        X, y = _load_local_data_set(cfg, data_set_name)
    # Splitting is decided by where the data came from, not by name alone:
    # an sklearn-fetched 'ISOLET' is split here instead of reaching the
    # return with the split variables unbound.
    return _split_train_test(cfg, data_set_name, X, y)


def _fetch_sklearn_data_set(cfg, data_set_name):
    """Fetch ``data_set_name`` through ``sklearn.datasets``; return ``(X, y)``."""
    if data_set_name == 'OLIVETTI':
        data_set = skd.fetch_olivetti_faces(shuffle=True)
    else:
        # NOTE: ``fetch_mldata`` was deprecated in scikit-learn 0.20 and
        # removed in 0.22 (mldata.org is offline); this path needs an old
        # scikit-learn or a port to ``fetch_openml``.
        data_set = skd.fetch_mldata(data_set_name)
    X, y = data_set.data, data_set.target
    if data_set_name == 'MNIST original' and cfg['pre_process'].getboolean('normalize'):
        # Presumably raw 0-255 pixel intensities; rescale to [0, 1].
        X = X / 255.
    return X, y


def _load_local_data_set(cfg, data_set_name):
    """Load a non-sklearn, non-ISOLET data set; return ``(X, y)``."""
    if data_set_name == 'LETTERS':
        X, y = fetch_load_letters()
        return X, y
    if data_set_name == 'SHREC14':
        X, y = load_shrec14(real=cfg['fetch']['real'], desc=cfg['fetch']['desc'])
        X = prep.normalize(X, norm=cfg['pre_process']['norm'])
        return X, y
    raise NameError('No data set {} found!'.format(data_set_name))


def _split_train_test(cfg, data_set_name, X, y):
    """Split ``X``/``y`` into ``(x_tr, x_te, y_tr, y_te)``."""
    if data_set_name == 'MNIST original':
        # Fixed split: first 60000 samples train, the rest test
        # (MNIST's standard 60000-image training set).
        return X[:60000], X[60000:], y[:60000], y[60000:]
    test_size = cfg['train_test'].getfloat('test_size')
    # Stratify so both sides keep the class proportions of ``y``.
    x_tr, x_te, y_tr, y_te = train_test_split(X, y, test_size=test_size, stratify=y)
    return x_tr, x_te, y_tr, y_te
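# 评论列表 ("comment list") -- scraped web-page residue, not code
# 文章目录 ("table of contents") -- scraped web-page residue, not code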