models.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:kdd2017 作者: JinpengLI 项目源码 文件源码
def remove_outliers_by_classifier(X, y, dates, model, m=0.9):
    #xgboost = XGBoost(max_depth=2, num_round=6000)
    if np.isnan(X).any():
        print("X contains NaN")
    if np.isinf(X).any():
        print("X contains inf")
    if np.isnan(np.log(y)).any():
        print("y contains nan")
    if np.isinf(np.log(y)).any():
        print("y contains inf")
    print("X=", X.shape)
    print("y=", y.shape)
    model.fit(X, y)
    y_pred = model.predict(X)
    diff_values = np.abs(y_pred - y)
    abs_diff_vals = np.abs(diff_values)
    sorted_indexes = sorted(range(len(abs_diff_vals)), key = lambda x: abs_diff_vals[x])
    sorted_indexes_lead = sorted_indexes[:int(len(abs_diff_vals)*m)]
    return X[sorted_indexes_lead], y[sorted_indexes_lead], dates[sorted_indexes_lead]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号