def online(X_org, y_org, test_x, test_uid, n_folds=5, shuffle=False):
    """Train a stacked (blended) ensemble and save its test-set predictions.

    Every base classifier is fitted once per stratified fold. Its scores on
    the held-out part of each fold fill one column of the blender's training
    matrix, and its scores on ``test_x`` are averaged over the folds to fill
    the matching column of the blender's test matrix. A ``RidgeCV`` model is
    then fitted on the training matrix, its predictions for the test matrix
    are rescaled to [0, 1], and the result is handed to ``save_submission``.

    Args:
        X_org: Training features. Rows are selected with integer index
            arrays, so this is presumably a 2-D NumPy array -- confirm
            against the caller.
        y_org: Training labels aligned with the rows of ``X_org``. Column 1
            of ``predict_proba`` is used as the score, which assumes a
            binary target.
        test_x: Features of the samples to score.
        test_uid: Identifiers passed through unchanged to
            ``save_submission``.
        n_folds: Number of stratified folds. Defaults to 5, the value that
            used to be hard-coded.
        shuffle: When true, permute the training rows before building the
            folds. Defaults to False, matching the previous behaviour.

    Returns:
        None. The predictions are written out through ``save_submission``.
    """
    X, y = X_org, y_org
    X_submission = test_x

    if shuffle:
        order = np.random.permutation(y.size)
        X = X[order]
        y = y[order]

    # Materialise the folds once so every base model sees the same splits.
    skf = list(StratifiedKFold(y, n_folds))

    clfs = [
        RandomForestClassifier().set_params(**INITIAL_PARAMS.get("RFC:one", {})),
        ExtraTreesClassifier().set_params(**INITIAL_PARAMS.get("ETC:one", {})),
        GradientBoostingClassifier().set_params(**INITIAL_PARAMS.get("GBC:one", {})),
        LogisticRegression().set_params(**INITIAL_PARAMS.get("LR:one", {})),
        xgb.XGBClassifier().set_params(**INITIAL_PARAMS.get("XGBC:two", {})),
        xgb.XGBClassifier().set_params(**INITIAL_PARAMS.get("XGBC:one", {})),
    ]

    # The print calls below take a single pre-formatted string so that they
    # emit the same text under Python 2 and Python 3.
    print("Creating train and test sets for blending.")

    # One column per base model: out-of-fold scores for the training rows,
    # fold-averaged scores for the submission rows.
    dataset_blend_train = np.zeros((X.shape[0], len(clfs)))
    dataset_blend_test = np.zeros((X_submission.shape[0], len(clfs)))

    for j, clf in enumerate(clfs):
        print("%s %s" % (j, clf))
        dataset_blend_test_j = np.zeros((X_submission.shape[0], len(skf)))
        for i, (train, test) in enumerate(skf):
            print("Fold %s" % (i,))
            clf.fit(X[train], y[train])
            # Held-out rows are scored by a model that never saw them.
            dataset_blend_train[test, j] = clf.predict_proba(X[test])[:, 1]
            dataset_blend_test_j[:, i] = clf.predict_proba(X_submission)[:, 1]
        dataset_blend_test[:, j] = dataset_blend_test_j.mean(1)

    print("Blending.")
    # RidgeCV is a regressor, so predict() returns raw scores rather than
    # probabilities; that is why the output is rescaled below.
    # NOTE(review): np.linspace(0, 200) includes alpha=0 (no regularisation)
    # in the candidate grid -- confirm that this is intended.
    clf = linear_model.RidgeCV(
        alphas=np.linspace(0, 200), cv=LM_CV_NUM)
    clf.fit(dataset_blend_train, y)
    print("%s %s" % (clf.coef_, clf.intercept_))
    y_submission = clf.predict(dataset_blend_test)

    print("Linear stretch of predictions to [0,1]")
    lowest = y_submission.min()
    span = y_submission.max() - lowest
    if span > 0:
        y_submission = (y_submission - lowest) / span
    else:
        # Every prediction is identical, so min-max scaling would divide by
        # zero and turn the whole submission into NaN. Keep the constant
        # score and only force it into the valid range.
        y_submission = np.clip(y_submission, 0.0, 1.0)

    print("blend result")
    save_submission(os.path.join(consts.SUBMISSION_PATH,
                                 MODEL_NAME + '_' + strftime("%m_%d_%H_%M_%S", localtime()) + '.csv'),
                    test_uid, y_submission)
# Comment list
# Table of contents