policy_net_script.py source code

python

Project: WaNN   Author: TeoZosa
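```python
from sklearn import metrics, model_selection


def K_FoldValidation(estimator, XMatrix, yVector, numFolds):
    """Run K-fold cross-validation and print the test-set MSE of each fold.

    XMatrix and yVector must support NumPy-style integer-array indexing.
    """
    numTrainingExamples = len(XMatrix)
    K = numFolds
    if K < 2:
        raise ValueError("K must be greater than or equal to 2")
    elif K > numTrainingExamples:
        raise ValueError("K must be less than or equal to the number of training examples")
    # model_selection.KFold takes the number of folds; the samples go to split()
    K_folds = model_selection.KFold(n_splits=K)

    for k, (train_index, test_index) in enumerate(K_folds.split(XMatrix)):
        X_train, X_test = XMatrix[train_index], XMatrix[test_index]
        y_train, y_test = yVector[train_index], yVector[test_index]
        # Fit. The logdir keyword must be supported by the estimator passed in;
        # standard scikit-learn estimators do not accept it.
        estimator.fit(X_train, y_train, logdir='')

        # Predict and score; mean_squared_error expects (y_true, y_pred)
        score = metrics.mean_squared_error(y_test, estimator.predict(X_test))
        print('Iteration {0:d} MSE: {1:f}'.format(k + 1, score))
```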
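As a dependency-free sketch of the same procedure (fit on K-1 folds, score on the held-out fold, report the MSE per fold), the following pure-Python code reimplements the index splitting and the scoring loop. The helper names (`kfold_indices`, `k_fold_mse`, `fit_mean`, `predict_mean`) are hypothetical and are not part of WaNN or scikit-learn. The fold sizing is meant to follow scikit-learn's unshuffled `KFold`, where the first `n % K` folds receive one extra sample, and the mean-predictor "estimator" is only a placeholder for a real model.

```python
from typing import Callable, Iterator, List, Sequence, Tuple


def kfold_indices(n_samples: int, n_splits: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_index, test_index) pairs for unshuffled K-fold splitting.

    The first n_samples % n_splits folds get one extra sample, so fold
    sizes differ by at most one.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must satisfy 2 <= n_splits <= n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        stop = start + base + (1 if fold < extra else 0)
        test_index = list(range(start, stop))
        train_index = list(range(0, start)) + list(range(stop, n_samples))
        yield train_index, test_index
        start = stop


def mean_squared_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Average of the squared differences between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def k_fold_mse(
    fit: Callable[[List[float], List[float]], object],
    predict: Callable[[object, List[float]], List[float]],
    X: Sequence[float],
    y: Sequence[float],
    n_splits: int,
) -> List[float]:
    """Fit on K-1 folds, score on the held-out fold, return the per-fold MSE."""
    scores = []
    for train_index, test_index in kfold_indices(len(X), n_splits):
        model = fit([X[i] for i in train_index], [y[i] for i in train_index])
        y_pred = predict(model, [X[i] for i in test_index])
        scores.append(mean_squared_error([y[i] for i in test_index], y_pred))
    return scores


# Hypothetical baseline "estimator": always predicts the mean training target.
def fit_mean(X_train: List[float], y_train: List[float]) -> float:
    return sum(y_train) / len(y_train)


def predict_mean(model: float, X_test: List[float]) -> List[float]:
    return [model] * len(X_test)


if __name__ == "__main__":
    X = [0.0, 1.0, 2.0, 3.0]
    y = [1.0, 2.0, 3.0, 4.0]
    # prints "Iteration 1 MSE: 4.250000" and "Iteration 2 MSE: 4.250000"
    for k, score in enumerate(k_fold_mse(fit_mean, predict_mean, X, y, n_splits=2)):
        print("Iteration {0:d} MSE: {1:f}".format(k + 1, score))
```

With 10 samples and K = 3, the test folds have sizes 4, 3 and 3, and every sample appears in exactly one test fold, so each example is scored once across the K iterations.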