knn_cv.py source code

python

Project: ML  Author: saurabhsuman47
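import numpy as np
from sklearn import neighbors
from sklearn.model_selection import KFold


def knn_cv(post_features, post_class, n_folds, n_neighbors, length_dataset=-1):
    # Use the full dataset unless a smaller prefix length is requested.
    if length_dataset == -1:
        length_dataset = len(post_class)
    # The old sklearn.cross_validation.KFold(n=..., n_folds=...) API was removed
    # in scikit-learn 0.20; model_selection.KFold takes n_splits and yields folds via split().
    cv = KFold(n_splits=n_folds, shuffle=True)
    train_accuracy = []
    test_accuracy = []

    # Splitting an index array restricts the folds to the first length_dataset samples.
    for train, test in cv.split(np.arange(length_dataset)):
        knn = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors)
        knn.fit(post_features[train], post_class[train])
        # Accuracy on the training fold (optimistic for k-NN: each point is its own nearest neighbor).
        train_accuracy.append(knn.score(post_features[train], post_class[train]))
        # Accuracy on the held-out fold.
        test_accuracy.append(knn.score(post_features[test], post_class[test]))

    # Mean accuracy across folds, equivalent to sum(...) / n_folds.
    return np.mean(train_accuracy), np.mean(test_accuracy)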
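
For comparison, the same pair of fold-averaged train/test accuracies can be obtained with scikit-learn's cross_validate helper. This is a minimal sketch, assuming scikit-learn 0.20 or later; knn_cv_builtin is a hypothetical name, and the bundled iris dataset is used purely as stand-in data because the project's post features are not shown. Unlike knn_cv, it always uses every sample (there is no length_dataset prefix option).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_validate
from sklearn.neighbors import KNeighborsClassifier


def knn_cv_builtin(X, y, n_folds, n_neighbors, seed=0):
    """Hypothetical helper: mean train/test accuracy over shuffled K folds."""
    # Fixed random_state makes the shuffled folds reproducible.
    cv = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    scores = cross_validate(
        KNeighborsClassifier(n_neighbors=n_neighbors),
        X,
        y,
        cv=cv,
        scoring="accuracy",
        return_train_score=True,  # also score each fitted model on its training fold
    )
    return np.mean(scores["train_score"]), np.mean(scores["test_score"])


# Stand-in data (assumption): the iris dataset shipped with scikit-learn.
X, y = load_iris(return_X_y=True)
train_acc, test_acc = knn_cv_builtin(X, y, n_folds=5, n_neighbors=5)
print(f"train={train_acc:.3f} test={test_acc:.3f}")
```

Letting cross_validate drive the loop avoids hand-written fold bookkeeping, while passing an explicit KFold object keeps the shuffled, unstratified splitting behavior of the original function.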