lexent.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:emnlp2016 作者: stephenroller 项目源码 文件源码
def cv_trials(X, y, folds, model, hyper):
    """Run cross-validation with per-fold hyperparameter selection.

    For every (train, val, test) split in ``folds`` each point of the
    hyperparameter grid is fitted on the training portion and scored by
    weighted F1 on the validation portion. The best-scoring setting is then
    used to predict the test portion of that fold.

    Args:
        X: indexable feature container; indexed as ``X[train]`` etc.
        y: indexable label container of length N, indexed the same way.
        folds: iterable of ``(train, val, test)`` index triples.
        model: estimator exposing ``set_params``, ``fit`` and ``predict``.
            It is mutated (re-parameterised and re-fitted) in place.
        hyper: grid specification handed to ``ParameterGrid``.

    Returns:
        A tuple ``(pooled_eval, predictions, cv_scores)``:
          * ``pooled_eval`` -- dict with keys 'f1', 'p', 'r', 'a' computed
            over all N samples at once from the collected predictions.
          * ``predictions`` -- dict with arrays 'pred', 'proba', 'foldno',
            each of length N.
          * ``cv_scores`` -- the per-fold metric dicts after being passed
            through ``consolidate`` (defined elsewhere in this file).
    """
    N = len(y)

    cv_scores = []
    predictions = {
        # Builtin bool instead of the np.bool alias, which newer NumPy
        # releases no longer provide; the resulting dtype is the same.
        'pred': np.zeros(N, dtype=bool),
        # NOTE(review): 'proba' is never written in this function, so it is
        # returned as all zeros -- confirm whether callers rely on it.
        'proba': np.zeros(N),
        # -1 marks samples that never appear in any test split.
        'foldno': np.zeros(N, dtype=np.int32) - 1,
    }
    pg = list(ParameterGrid(hyper))
    for foldno, (train, val, test) in enumerate(folds):
        train_X, train_y = X[train], y[train]
        val_X, val_y = X[val], y[val]
        test_X, test_y = X[test], y[test]
        best_params = None
        best_val_f1 = None
        for these_params in pg:
            model.set_params(**these_params)
            model.fit(train_X, train_y)
            this_val_f1 = metrics.f1_score(val_y, model.predict(val_X), average="weighted")
            # Compare against None explicitly: an empty parameter dict is a
            # legitimate grid point and must not be mistaken for "no best
            # yet", otherwise it would always be overwritten by the next
            # candidate regardless of score.
            if best_params is None or this_val_f1 > best_val_f1:
                best_params = these_params
                best_val_f1 = this_val_f1
        # With a single grid point the model is already fitted with it;
        # otherwise the last fitted setting may not be the best one, so refit.
        if len(pg) > 1:
            model.set_params(**best_params)
            model.fit(train_X, train_y)
        train_f1 = metrics.f1_score(train_y, model.predict(train_X), average="weighted")

        preds_y = model.predict(test_X)
        predictions['pred'][test] = preds_y

        predictions['foldno'][test] = foldno

        fold_eval = {'f1': metrics.f1_score(test_y, preds_y, average="weighted"),
                      'p': metrics.precision_score(test_y, preds_y, average="weighted"),
                      'r': metrics.recall_score(test_y, preds_y, average="weighted"),
                      'a': metrics.accuracy_score(test_y, preds_y)}
        # Single-argument call form: prints identically under Python 2's
        # print statement and Python 3's print function.
        print("[%02d] Best hyper [train %.3f -> val %.3f -> test %.3f] %s" % (foldno, train_f1, best_val_f1, fold_eval['f1'], best_params))


        cv_scores.append(fold_eval)
        np.set_printoptions(suppress=True)

    # now we want to compute global evaluations, and consolidate metrics
    cv_scores = consolidate(cv_scores)

    # Pooled metrics are computed over every sample; entries never assigned
    # to a test split keep their initial False prediction.
    preds_y = predictions['pred']
    pooled_eval = {'f1': metrics.f1_score(y, preds_y, average="weighted"),
                    'p': metrics.precision_score(y, preds_y, average="weighted"),
                    'r': metrics.recall_score(y, preds_y, average="weighted"),
                    'a': metrics.accuracy_score(y, preds_y)}

    return pooled_eval, predictions, cv_scores
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号