est_utils.py source code

python

Project: gcForest    Author: kingfengji
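import logging
import numpy as np

LOGGER = logging.getLogger(__name__)


def xgb_train(train_config, X_train, y_train, X_test, y_test):
    # Imported inside the function, so xgboost is only required when xgb_train is called.
    import xgboost as xgb
    LOGGER.info("X_train.shape={}, y_train.shape={}, X_test.shape={}, y_test.shape={}".format(
        X_train.shape, y_train.shape, X_test.shape, y_test.shape))
    # "param" is the xgboost parameter dict; "num_round" is the number of boosting rounds.
    param = train_config["param"]
    xg_train = xgb.DMatrix(X_train, label=y_train)
    xg_test = xgb.DMatrix(X_test, label=y_test)
    num_round = int(train_config["num_round"])
    # Evaluation metrics are logged for both sets after every boosting round.
    watchlist = [(xg_train, 'train'), (xg_test, 'test')]
    try:
        # evals is passed by keyword, which works across xgboost versions.
        bst = xgb.train(param, xg_train, num_round, evals=watchlist)
    except KeyboardInterrupt:
        LOGGER.info("Canceled by user's Ctrl-C action")
        return
    # argmax over axis 1 assumes predict() returns an (n_samples, n_classes)
    # probability matrix, e.g. with objective "multi:softprob".
    y_pred = np.argmax(bst.predict(xg_test), axis=1)
    acc = 100. * np.sum(y_pred == y_test) / len(y_test)
    LOGGER.info("accuracy={}%".format(acc))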
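The function only reads two keys from train_config, and its accuracy step relies on predict() returning one probability per class. The sketch below illustrates both without needing xgboost installed: the parameter values in train_config are hypothetical (not gcForest's settings), and accuracy_from_proba is a hypothetical helper that repeats the argmax-and-compare logic on a made-up probability matrix standing in for bst.predict(xg_test). Note that with objective "multi:softmax", predict() returns a 1-D array of class labels instead, and np.argmax(..., axis=1) would fail.

```python
import numpy as np

# Hypothetical train_config for xgb_train: "param" and "num_round" are the
# only keys the function reads. The values are illustrative assumptions.
train_config = {
    "param": {
        "objective": "multi:softprob",  # predict() yields shape (n_samples, num_class)
        "num_class": 3,
        "max_depth": 4,
        "eta": 0.1,
    },
    "num_round": "50",  # xgb_train converts this with int()
}


def accuracy_from_proba(proba, y_true):
    """Hypothetical helper: percent accuracy computed as in xgb_train,
    i.e. argmax over the class axis, then compare with the true labels."""
    y_pred = np.argmax(proba, axis=1)
    return 100.0 * np.sum(y_pred == y_true) / len(y_true)


# Made-up stand-in for bst.predict(xg_test) under multi:softprob.
proba = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.5, 0.4, 0.1],
])
y_true = np.array([0, 1, 2, 1])
print(accuracy_from_proba(proba, y_true))  # 75.0
```

Three of the four rows have their highest probability at the true class, so the reported accuracy is 75%.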