calibration_utils.py source code

python

Project: introspective    Author: numeristical
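import numpy as np
from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold

def train_and_calibrate_cv(model, X_tr, y_tr, cv=5):
    """Fit `model` on all of (X_tr, y_tr) and build a calibration function
    from out-of-fold predicted probabilities (stratified k-fold CV).

    prob_calibration_function is defined elsewhere in the project and is
    not shown in this snippet.
    """
    X_arr = np.asarray(X_tr)
    y_arr = np.asarray(y_tr)
    y_pred_xval = np.zeros(len(y_arr))
    # Shuffled stratified folds: each fold keeps the class proportions of y
    skf = StratifiedKFold(n_splits=cv, shuffle=True)
    for i, (train, test) in enumerate(skf.split(X_arr, y_arr), start=1):
        print("training fold {} of {}".format(i, cv))
        # Fit a fresh, unfitted copy so no fold model sees its own test rows
        model_copy = clone(model)
        model_copy.fit(X_arr[train, :], y_arr[train])
        # Score the held-out rows with the fold model, keeping P(class 1)
        y_pred_xval[test] = model_copy.predict_proba(X_arr[test, :])[:, 1]
    print("training full model")
    full_model = clone(model)
    full_model.fit(X_arr, y_arr)
    print("calibrating function")
    calib_func = prob_calibration_function(y_arr, y_pred_xval)
    return full_model, calib_func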
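The fold loop in this function produces out-of-fold probabilities, so every training row is scored by a model that never saw it, and those scores are what the calibration function is fit on. A minimal, self-contained sketch of the same two stages follows, assuming scikit-learn is installed. It uses scikit-learn's cross_val_predict with method="predict_proba" for the out-of-fold step. Because prob_calibration_function is not shown in this snippet, the sketch swaps in isotonic regression (IsotonicRegression) as a stand-in calibrator; this is not the author's calibration method. The function name train_and_calibrate_isotonic and the synthetic data are hypothetical.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict


def train_and_calibrate_isotonic(model, X_tr, y_tr, cv=5, random_state=0):
    """Hypothetical variant: out-of-fold probabilities + isotonic calibration.

    IsotonicRegression is a stand-in for prob_calibration_function,
    which is not shown in the original snippet.
    """
    skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=random_state)
    # cross_val_predict fits a clone of `model` on each training fold and
    # returns, for every row, the prediction made while that row was held out.
    y_pred_xval = cross_val_predict(
        clone(model), X_tr, y_tr, cv=skf, method="predict_proba"
    )[:, 1]
    # The returned model is fit on all of the training data.
    full_model = clone(model).fit(X_tr, y_tr)
    # Learn a monotone map from raw scores to probabilities; scores outside
    # the range seen during fitting are clipped instead of becoming NaN.
    iso = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    iso.fit(y_pred_xval, y_tr)
    return full_model, iso.predict


X, y = make_classification(n_samples=500, n_features=10, random_state=0)
model, calib_func = train_and_calibrate_isotonic(
    LogisticRegression(max_iter=1000), X, y
)
raw = model.predict_proba(X[:5])[:, 1]
calibrated = calib_func(raw)
print(np.round(raw, 3), np.round(calibrated, 3))
```

As in the original, the calibrator is fit on out-of-fold scores but applied to scores from the full model; this assumes the full model's score distribution is close to that of the fold models.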