# experiment.py (source file)
#
# Language: Python
# Views: 22 · Favorites: 0 · Likes: 0 · Comments: 0
#
# Project: question-answering · Author: emorynlp (project source · file source)
def test_lr_on_data(X_train, y_train, X_validate, y_validate, X_test, y_test, *, tune_on_test=False):
    """Train a logistic-regression model and report thresholded precision/recall/F1.

    The model is fit on the training split. A decision threshold is selected on
    the validation split with ``find_threshold_logistic`` and then applied,
    unchanged, to the test split. The test labels therefore never influence
    the threshold that produces the reported test metrics.

    Args:
        X_train: Training feature matrix passed to ``LogisticRegression.fit``.
        y_train: Nested iterable of training labels. It is flattened one level
            so it lines up row-for-row with ``X_train`` (presumably one label
            list per question -- TODO confirm against the caller).
        X_validate: Validation feature matrix.
        y_validate: Validation labels, passed as-is to the threshold helpers.
        X_test: Test feature matrix.
        y_test: Test labels, passed as-is to the evaluation helper.
        tune_on_test: If True, re-select the threshold on the test split
            itself (the previous behaviour). That is an oracle upper bound,
            not an unbiased test estimate, so it is off by default.

    Returns:
        Tuple ``(precision, recall, f1)`` measured on the test split.
    """
    # LogisticRegression.fit needs one label per row, so flatten one level.
    y_train_flatten = list(itertools.chain.from_iterable(y_train))

    # Train LR model.
    lr = LogisticRegression(solver='lbfgs')
    lr.fit(X_train, y_train_flatten)

    # predict_proba returns an (n_samples, n_classes) array whose columns follow
    # lr.classes_. The last column scores the largest class label (the
    # positive class for 0/1 labels).
    predictions_val = lr.predict_proba(X_validate)[:, -1]

    # Pick the threshold on validation data only. The helpers take the score
    # array twice; their exact semantics are defined elsewhere -- TODO confirm.
    best_threshold_validate = find_threshold_logistic(y_validate, predictions_val, predictions_val)
    precision_val, recall_val, f1_val = evaluate_with_threshold(
        y_validate, predictions_val, predictions_val, best_threshold_validate)
    globals.logger.info("Found threshold: %f. Precision/recall/f1 over validation set: %f/%f/%f",
                        best_threshold_validate, precision_val, recall_val, f1_val)

    # Score the test split.
    predictions_test = lr.predict_proba(X_test)[:, -1]
    if tune_on_test:
        # Oracle mode: the threshold is chosen using test labels, so the
        # resulting metrics are optimistic.
        threshold_test = find_threshold_logistic(y_test, predictions_test, predictions_test, verbose=True)
        threshold_source = "test"
    else:
        # Unbiased evaluation: reuse the threshold selected on validation data.
        threshold_test = best_threshold_validate
        threshold_source = "validation"

    precision, recall, f1 = evaluate_with_threshold(y_test, predictions_test, predictions_test, threshold_test)
    globals.logger.info("Using threshold %f (chosen on %s set). Precision/recall/f1 over test set: %f/%f/%f",
                        threshold_test, threshold_source, precision, recall, f1)

    return precision, recall, f1
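# Comments
# Table of contents
#
#
# Questions
#
#
# Interview experiences
#
#
# Articles
#
# WeChat
# Official account
#
# Scan the QR code to follow the official account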