s13_analyze_xgboost_data.py source code

python

Project: KAGGLE_AVITO_2016    Author: ZFTurbo
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import roc_auc_score

# append_items_info() and print_debug_data() are helpers defined elsewhere
# in s13_analyze_xgboost_data.py (not shown in this excerpt).


def output_critical_tests(train, features, target, model_path, test_size):
    """Score the validation tail of `train` with a saved XGBoost model and
    write the most confidently wrong predictions to an HTML file."""
    out_path = "cache/fails.html"
    gbm = xgb.Booster()
    gbm.load_model(model_path)

    # Column dtypes for ItemInfo_train.csv
    types2 = {
        'itemID': np.dtype(int),
        'categoryID': np.dtype(int),
        'title': np.dtype(str),
        'description': np.dtype(str),
        'images_array': np.dtype(str),
        'attrsJSON': np.dtype(str),
        'price': np.dtype(float),
        'locationID': np.dtype(int),
        'metroID': np.dtype(float),
        'lat': np.dtype(float),
        'lon': np.dtype(float),
    }

    print("Load ItemInfo_train.csv")
    items = pd.read_csv("../input/ItemInfo_train.csv", dtype=types2)
    items.fillna(-1, inplace=True)

    # Positional (unshuffled) holdout: the last `test_size` fraction validates
    split = round((1 - test_size) * len(train.index))
    X_train = train.iloc[:split]
    X_valid = train.iloc[split:]
    print('Length train:', len(X_train.index))
    print('Length valid:', len(X_valid.index))

    print("Validating...")
    check = gbm.predict(xgb.DMatrix(X_valid[features]))
    # Take the labels before merging in item info so they stay aligned with `check`
    y_valid = X_valid[target].values
    score = roc_auc_score(y_valid, check)
    print('Score: {}'.format(score))

    X_valid = append_items_info(X_valid, items)

    count = 0
    with open(out_path, "w", encoding='utf-8') as out:
        for i in range(len(y_valid)):
            # "Critical" = confidently wrong: label 1 scored below 0.1,
            # or label 0 scored above 0.9
            if abs(y_valid[i] - check[i]) > 0.9:
                if count > 100:
                    break
                print(y_valid[i], check[i])
                print_debug_data(out, X_valid, features, i, check[i], y_valid[i])
                count += 1
    print('Count critical: {} from {}'.format(count, len(check)))
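The selection step in the loop above can also be written without a Python-level loop. The sketch below uses a hypothetical helper, `find_critical`, that is not part of the original script. It assumes binary 0/1 labels and predicted probabilities supplied as array-likes, reuses the same 0.9 threshold, and returns the positions of confidently wrong predictions together with the total number of such rows.

```python
import numpy as np


def find_critical(y_true, y_pred, threshold=0.9, limit=None):
    """Return (positions, total) for confidently wrong predictions.

    Hypothetical helper (not from the original script). A row is treated as
    critical when |label - probability| > threshold, i.e. a positive scored
    below 1 - threshold or a negative scored above threshold.
    `limit` optionally caps how many positions are returned; `total`
    always counts every critical row.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    mask = np.abs(y_true - y_pred) > threshold
    positions = np.flatnonzero(mask)  # indices where mask is True
    total = positions.size
    if limit is not None:
        positions = positions[:limit]
    return positions, total


if __name__ == "__main__":
    y = np.array([1, 0, 1, 0, 1])
    p = np.array([0.05, 0.95, 0.80, 0.10, 0.50])
    pos, total = find_critical(y, p)
    print(pos.tolist(), total)  # → [0, 1] 2
```

Returning the uncapped total alongside the capped positions keeps the reported count accurate even when only the first N rows are written out; in the original loop the counter stops increasing once the cap is reached.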