train_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:black_holes 作者: codeforgoodconf 项目源码 文件源码
def train_test():
    df = pd.read_csv("data_preprocessed.csv",header=None)

    label_cols = df.columns[0:2]
    Y = df[label_cols]

    feature_cols = df.columns[2:len(df.columns)]
    X = df[feature_cols]

    X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=1)
    train_df = pd.concat([y_train,X_train],axis=1)
    test_df = pd.concat([y_test,X_test], axis=1)

    return train_df, test_df
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号