xgb.py source code

python

Project: kaggle-review  Author: daxiongshu
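import numpy as np
import pandas as pd

# build_feature, to4c, onehot_encode, xgb_model and params are defined elsewhere in the project.

def cv(flags):
    """Train on one fold, then save and score the predictions for the held-out split."""
    X, y, Xt, yt, idx = build_feature(flags)

    params['verbose_eval'] = 10

    if '4c' in flags.task:
        # 4-class task: remap the one-hot labels with to4c, then convert back to class indices.
        y = np.argmax(to4c(onehot_encode(y)), axis=1)
        yt = np.argmax(to4c(onehot_encode(yt)), axis=1)
    params['num_class'] = np.max(y) + 1  # labels are assumed to be 0-based integers
    model = xgb_model(params)
    print(X.shape, Xt.shape, y.shape, yt.shape)
    model.fit(X, y, Xt, yt, print_fscore=False)
    yp = model.predict(Xt)  # 2-D array, one column per class

    # One row per held-out sample: class1..classK, the true label, and the sample ID.
    s = pd.DataFrame(yp, columns=['class%d' % i for i in range(1, yp.shape[1] + 1)])
    s['real'] = np.array(yt)
    s['ID'] = idx
    path = flags.data_path
    fold = flags.fold
    s.to_csv('%s/cv_%d.csv' % (path, fold), index=False)

    from utils.np_utils.utils import cross_entropy
    print(cross_entropy(yt, yp))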
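
The final line of cv scores the predictions with the project's own cross_entropy helper, whose implementation is not shown here. As a rough sketch of what such a score usually computes, the function below takes integer labels and a matrix of per-class probabilities and returns the mean negative log-likelihood. The name multiclass_cross_entropy, the eps clipping, the row renormalisation and the toy data are all assumptions made for illustration; the project's helper may differ.

```python
import numpy as np


def multiclass_cross_entropy(y_true, y_prob, eps=1e-15):
    """Mean negative log-likelihood of integer labels under predicted class probabilities.

    Hypothetical stand-in for the project's cross_entropy helper.
    """
    y_true = np.asarray(y_true, dtype=int)
    # Clip to avoid log(0) when a class receives zero probability.
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1.0)
    # Renormalise so each row still sums to 1 after clipping.
    y_prob = y_prob / y_prob.sum(axis=1, keepdims=True)
    # Select, for every row, the probability assigned to its true class.
    picked = y_prob[np.arange(len(y_true)), y_true]
    return float(-np.mean(np.log(picked)))


# Toy data, made up for illustration: 3 samples, 3 classes.
yt_demo = np.array([0, 2, 1])
yp_demo = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
])
loss = multiclass_cross_entropy(yt_demo, yp_demo)
print(loss)
```

A lower value means the model put more probability on the true classes. A model that always predicts a uniform distribution over K classes scores log(K), which is a handy baseline when reading the number printed by cv.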