run_model.py source code

python

Project: website-fingerprinting  Author: AxelGoetz
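from sklearn.model_selection import StratifiedKFold

# `get_X_y` and `scoring_methods` are defined elsewhere in the project.


def k_fold_validation(model, monitored_data, unmonitored_data, k, random_state=123):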
    """
    Performs stratified k-fold cross-validation on a model. For each fold, the model is fitted on the
    training split and its predictions on the test split are scored with `scoring_methods.evaluate_model`.

    @param model is a machine learning model that has the functions `fit(X, y)` and `predict(X)`
    @param monitored_data is an array-like object with the following structure: `[(features, value)]`
    @param unmonitored_data is also an array-like object: `[features]`
    @param k is the number of folds
    @param random_state is the seed used to shuffle the data before splitting

    @return a list of scores with the following structure: `[{scoring_method: score}]`, one entry per fold (length `k`)
    """
    X, y = get_X_y(monitored_data, unmonitored_data)
    skf = StratifiedKFold(n_splits=k, random_state=random_state, shuffle=True)

    evaluations = []
    for i, (train, test) in enumerate(skf.split(X, y), start=1):
        print("Starting split {}".format(i))
        X_train, X_test = X[train], X[test]
        y_train, y_test = y[train], y[test]

        print("Fitting data")
        model.fit(X_train, y_train)

        print("Predicting")
        prediction = model.predict(X_test)

        # Score this fold's predictions against the held-out labels
        evaluations.append(scoring_methods.evaluate_model(prediction, y_test))
        print(evaluations[-1])

    return evaluations
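
The function above depends on `get_X_y` and `scoring_methods`, which are not shown here. The following is a minimal, self-contained sketch of the same loop, not the project's actual code. It assumes `X` and `y` are already NumPy arrays, uses a hypothetical `evaluate_model` stand-in that reports accuracy only, and the name `k_fold_scores`, the synthetic data and the choice of `KNeighborsClassifier` are all illustrative assumptions.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier


def evaluate_model(prediction, y_true):
    # Hypothetical stand-in for scoring_methods.evaluate_model:
    # returns a {scoring_method: score} dict with accuracy only.
    return {"accuracy": accuracy_score(y_true, prediction)}


def k_fold_scores(model, X, y, k, random_state=123):
    # Same loop as k_fold_validation, but takes X and y directly
    # (assumption: they are NumPy arrays, so X[train] indexing works).
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=random_state)
    evaluations = []
    for train, test in skf.split(X, y):
        model.fit(X[train], y[train])
        prediction = model.predict(X[test])
        evaluations.append(evaluate_model(prediction, y[test]))
    return evaluations


# Synthetic data: two classes with 20 samples each and 4 features.
rng = np.random.RandomState(0)
X = np.vstack([rng.normal(0.0, 1.0, (20, 4)), rng.normal(5.0, 1.0, (20, 4))])
y = np.array([0] * 20 + [1] * 20)

scores = k_fold_scores(KNeighborsClassifier(n_neighbors=3), X, y, k=5)
for fold, score in enumerate(scores, start=1):
    print("Fold {}: {}".format(fold, score))
```

Because the split is stratified, each of the 5 test folds here contains 4 samples from each class, so the class balance in every fold matches that of the full data set.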