def test_lr_on_data(X_train, y_train, X_validate, y_validate, X_test, y_test):
    """Fit a logistic-regression model and report thresholded metrics.

    The model is trained on ``X_train`` with the flattened ``y_train`` labels.
    For the validation split and then the test split, the last column of
    ``predict_proba`` is used as the per-sample score, a threshold is obtained
    from ``find_threshold_logistic``, and precision/recall/F1 at that threshold
    come from ``evaluate_with_threshold``. Both results are logged at INFO
    level through ``globals.logger``.

    Args:
        X_train: Training features passed to ``LogisticRegression.fit``.
        y_train: Iterable of label iterables; flattened one level before
            fitting.
        X_validate: Validation features.
        y_validate: Validation labels, forwarded unchanged to the helpers.
        X_test: Test features.
        y_test: Test labels, forwarded unchanged to the helpers.

    Returns:
        The ``(precision, recall, f1)`` triple computed on the test split.
    """
    # Flatten the nested training labels by one level.
    flat_train_labels = list(itertools.chain.from_iterable(y_train))

    # Train LR model.
    model = LogisticRegression(solver='lbfgs')
    model.fit(X_train, flat_train_labels)

    def _last_column_scores(features):
        # One score per sample: the final entry of each predict_proba row.
        return array([row[-1] for row in model.predict_proba(features)])

    # --- Validation split ---
    val_scores = _last_column_scores(X_validate)
    # The same score vector is passed for both score arguments of the helpers.
    val_threshold = find_threshold_logistic(y_validate, val_scores, val_scores)
    val_precision, val_recall, val_f1 = evaluate_with_threshold(
        y_validate, val_scores, val_scores, val_threshold)
    val_message = "Found threshold: %f. Precision/recall/f1 over validation set: %f/%f/%f" % (
        val_threshold, val_precision, val_recall, val_f1)
    globals.logger.info(val_message)

    # --- Test split ---
    test_scores = _last_column_scores(X_test)
    # NOTE(review): the threshold used for the test metrics is derived from
    # y_test itself rather than reusing the validation threshold; presumably
    # this makes the reported test numbers optimistic -- confirm it is intended.
    test_threshold = find_threshold_logistic(y_test, test_scores, test_scores, verbose=True)
    test_precision, test_recall, test_f1 = evaluate_with_threshold(
        y_test, test_scores, test_scores, test_threshold)
    test_message = "Found threshold: %f. Precision/recall/f1 over test set: %f/%f/%f" % (
        test_threshold, test_precision, test_recall, test_f1)
    globals.logger.info(test_message)

    return test_precision, test_recall, test_f1
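# Comment list (web-page navigation residue; not part of the code)
# Table of contents (web-page navigation residue; not part of the code)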