p1b1.py source file

python

Project: Benchmarks  Author: ECP-CANDLE
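import numpy as np
import pandas as pd
from sklearn.preprocessing import MaxAbsScaler


# `get_p1_file` and `seed` are not shown in this excerpt; they are assumed to be
# defined elsewhere in p1b1.py.
def load_data(shuffle=True, n_cols=None):
    """Load the P1B1 train/test CSVs and return them as max-abs-scaled arrays."""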
    train_path = get_p1_file('http://ftp.mcs.anl.gov/pub/candle/public/benchmarks/P1B1/P1B1.train.csv')
    test_path = get_p1_file('http://ftp.mcs.anl.gov/pub/candle/public/benchmarks/P1B1/P1B1.test.csv')

    usecols = list(range(n_cols)) if n_cols else None

    df_train = pd.read_csv(train_path, engine='c', usecols=usecols)
    df_test = pd.read_csv(test_path, engine='c', usecols=usecols)

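    # Drop the case identifier column and cast the remaining columns to float32.
    # The positional axis argument, drop('case_id', 1), was removed in pandas 2.0.
    df_train = df_train.drop(columns='case_id').astype(np.float32)
    df_test = df_test.drop(columns='case_id').astype(np.float32)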

    if shuffle:
        df_train = df_train.sample(frac=1, random_state=seed)
        df_test = df_test.sample(frac=1, random_state=seed)

    # DataFrame.as_matrix() was removed in pandas 1.0; to_numpy() replaces it.
    X_train = df_train.to_numpy()
    X_test = df_test.to_numpy()

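    # Fit one scaler on train and test stacked together so that both splits
    # share the same per-column scale, then split the rows back apart.
    n_train = X_train.shape[0]
    scaler = MaxAbsScaler()
    mat = np.concatenate((X_train, X_test), axis=0)
    mat = scaler.fit_transform(mat)

    X_train = mat[:n_train, :]
    X_test = mat[n_train:, :]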

    return X_train, X_test
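
The scaling step in `load_data` can be tried without downloading the P1B1 files. The following is a minimal NumPy-only sketch of the same idea: stack the two splits, divide each column by its maximum absolute value, and split the rows back. The function name `max_abs_scale_jointly` and the tiny arrays are hypothetical and are not part of the Benchmarks project. The sketch only mirrors what `MaxAbsScaler` does for dense input.

```python
import numpy as np


def max_abs_scale_jointly(train, test):
    """Scale train and test together by each column's maximum absolute value.

    Hypothetical helper for illustration; not part of p1b1.py.
    """
    n_train = train.shape[0]
    stacked = np.concatenate((train, test), axis=0)
    # Per-column maximum absolute value over both splits combined.
    scale = np.abs(stacked).max(axis=0)
    # All-zero columns would divide by zero; leave them unchanged instead.
    scale[scale == 0.0] = 1.0
    scaled = stacked / scale
    return scaled[:n_train], scaled[n_train:]


# Synthetic data, made up for illustration only.
train = np.array([[1.0, -4.0, 0.0], [2.0, 2.0, 0.0]], dtype=np.float32)
test = np.array([[-8.0, 1.0, 0.0]], dtype=np.float32)

scaled_train, scaled_test = max_abs_scale_jointly(train, test)
print(scaled_train)
print(scaled_test)
```

Because the scale is computed from both splits, every value ends up in [-1, 1], but test-set statistics also influence the training features. Fitting on the training rows alone would avoid that, at the cost of test values that may fall outside [-1, 1].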