def load_data(shuffle=True, n_cols=None):
    """Load the P1B1 train/test CSV files and scale them with one shared scaler.

    Both files are read with pandas, the ``case_id`` column is dropped, the
    remaining columns are cast to ``float32``, and the two sets are scaled
    together with a single ``MaxAbsScaler`` fitted on their concatenation.

    Args:
        shuffle: When true, rows of each frame are shuffled independently via
            ``DataFrame.sample(frac=1, random_state=seed)``. ``seed`` is not
            defined in this function; it is expected to exist at module level.
        n_cols: If truthy, only the first ``n_cols`` columns of each CSV are
            read (``case_id`` must be among them, otherwise ``drop`` raises
            ``KeyError``). ``None`` or ``0`` reads every column.

    Returns:
        A ``(X_train, X_test)`` tuple of 2-D arrays holding the scaled rows of
        the train and test files respectively.
    """
    # NOTE(review): get_p1_file is defined elsewhere; presumably it fetches or
    # caches the URL and returns a local path usable by pd.read_csv -- confirm.
    train_path = get_p1_file('http://ftp.mcs.anl.gov/pub/candle/public/benchmarks/P1B1/P1B1.train.csv')
    test_path = get_p1_file('http://ftp.mcs.anl.gov/pub/candle/public/benchmarks/P1B1/P1B1.test.csv')

    usecols = list(range(n_cols)) if n_cols else None

    df_train = pd.read_csv(train_path, engine='c', usecols=usecols)
    df_test = pd.read_csv(test_path, engine='c', usecols=usecols)

    # Pass axis by keyword: the positional form ``drop(label, 1)`` was
    # deprecated and no longer works in pandas 2.0+.
    df_train = df_train.drop('case_id', axis=1).astype(np.float32)
    df_test = df_test.drop('case_id', axis=1).astype(np.float32)

    if shuffle:
        df_train = df_train.sample(frac=1, random_state=seed)
        df_test = df_test.sample(frac=1, random_state=seed)

    # ``.values`` replaces ``DataFrame.as_matrix()``, which was removed in
    # pandas 1.0; it is available on both old and new pandas versions.
    X_train = df_train.values
    X_test = df_test.values

    # Remember the split point before the names are rebound to slices below.
    n_train = X_train.shape[0]

    # The scaler is fitted on train and test rows together, so both halves
    # share the same per-column scale factors.
    scaler = MaxAbsScaler()
    mat = np.concatenate((X_train, X_test), axis=0)
    mat = scaler.fit_transform(mat)

    X_train = mat[:n_train, :]
    X_test = mat[n_train:, :]
    return X_train, X_test
评论列表
文章目录