test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:EasyLSTM 作者: cullengao 项目源码 文件源码
def test_all_metrics(model, data=None, usage_ratio=1):
    """Print precision, recall and F1 on the test split and plot the confusion matrix.

    Args:
        model: Object exposing ``predict_classes(X)``; it is called on the
            test inputs to obtain predicted class labels.
        data: Optional 4-item sequence ``(X_train, y_train, X_test, y_test)``.
            When ``None``, the splits are loaded with
            ``read_data(usage_ratio=usage_ratio)``. Pass the same splits that
            were used for training so the evaluation is consistent.
        usage_ratio: Forwarded to ``read_data``; unused when ``data`` is given.

    Returns:
        None. The scores are written to stdout and the confusion matrix is
        handed to ``plot_confusion_matrix``.
    """
    if data is None:
        data = read_data(usage_ratio=usage_ratio)
    # Only the test split is evaluated; unpacking all four items still
    # rejects a ``data`` value of the wrong length.
    _X_train, _y_train, X_test, y_test = data

    y_pred = model.predict_classes(X_test)
    # assumes y_test is one-hot encoded as (n_samples, n_classes) -- TODO confirm
    y_ground = np.argmax(y_test, axis=1)

    # NOTE(review): these metrics are called with their default averaging,
    # which presumably restricts this function to binary problems -- confirm
    # against the installed scikit-learn before using it on multiclass data.
    precision = sk.metrics.precision_score(y_ground, y_pred)
    recall = sk.metrics.recall_score(y_ground, y_pred)
    f1 = sk.metrics.f1_score(y_ground, y_pred)

    # Single-argument print(...) parses on both Python 2 and Python 3. The
    # two spaces after "=" reproduce what the former print statement emitted.
    print("precision_score =  %s" % (precision,))
    print("recall_score =  %s" % (recall,))
    print("f1_score =  %s" % (f1,))

    # Overall accuracy, the numeric confusion matrix and the ROC curve were
    # disabled in the original; only the confusion-matrix plot is produced.
    plot_confusion_matrix(y_ground, y_pred)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号