stats.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:EarlyWarning 作者: wjlei1990 项目源码 文件源码
def train(df_train, df_test, trainer=None):
    """Fit a regression model on ``df_train``, evaluate it on ``df_test``, and plot.

    Features and targets are extracted from each split with
    ``extract_feature_and_y``. The model is fit and applied by ``trainer``.
    The test-set mean squared error and the standard deviation of the test
    residuals are printed. ``plot_y`` is then called with the true and
    predicted targets for both splits.

    Args:
        df_train: Training data, in whatever form ``extract_feature_and_y``
            accepts (presumably a pandas DataFrame -- confirm against caller).
        df_test: Test data, same form as ``df_train``.
        trainer: Optional callable ``trainer(train_x, train_y, test_x)``. It
            must return a dict with at least ``"y"`` (predictions for
            ``test_x``) and ``"train_y"`` (predictions for ``train_x``).
            Defaults to ``train_ridge_linear_model``. Pass another trainer
            with the same signature (e.g. a lasso or elastic-net trainer) to
            switch models without editing this function.

    Returns:
        dict with keys ``"mse"`` (test MSE), ``"std"`` (population standard
        deviation of the test residuals) and ``"info"`` (the dict returned by
        ``trainer``).
    """
    if trainer is None:
        # Resolved at call time rather than as a default argument, so the
        # helper may be defined anywhere in the module.
        trainer = train_ridge_linear_model

    train_x, train_y = extract_feature_and_y(df_train)
    print("train x and y shape: {0} and {1}".format(
        train_x.shape, train_y.shape))
    test_x, test_y = extract_feature_and_y(df_test)
    print("test x and y shape: {0} and {1}".format(
        test_x.shape, test_y.shape))

    info = trainer(train_x, train_y, test_x)
    test_pred = info["y"]

    # NOTE(review): presumably sklearn.metrics.mean_squared_error -- confirm
    # against the file's imports.
    mse = mean_squared_error(test_y, test_pred)
    # Flatten both sides before subtracting. If the predictions come back as
    # a column vector (n, 1) while test_y is (n,), plain subtraction would
    # broadcast to an (n, n) matrix and silently produce a meaningless std.
    residuals = np.ravel(test_y) - np.ravel(test_pred)
    std = np.std(residuals)
    print("MSE on test data: %f" % mse)
    print("std of error on test data: %f" % std)

    plot_y(train_y, info["train_y"], test_y, test_pred)

    return {"mse": mse, "std": std, "info": info}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号