cross_validation.py source file

python

Project: ottertune  Author: cmu-db
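import numpy as np

# NOTE: CombinedScore is not defined in this excerpt; it is assumed to be
# defined elsewhere in cross_validation.py.


def combine_rmse_gpvar(grid_scores, w_rmse=0.8, w_gpvar=0.2):
    from sklearn.preprocessing import minmax_scale

    # Collect each grid point's mean RMSE (column 0) and mean GP variance (column 1)
    scaled_scores = np.empty((len(grid_scores), 2))
    for i, scores in enumerate(grid_scores):
        scaled_scores[i, 0] = scores.mean_scores[0]
        scaled_scores[i, 1] = scores.mean_scores[1]

    # Rank grid points by each raw metric (ascending, so lowest comes first)
    rmse_sort_indices = np.argsort(scaled_scores[:, 0])
    gpvar_sort_indices = np.argsort(scaled_scores[:, 1])

    # Scale RMSEs and GP variances column-wise to [0, 1] so the weights are comparable
    scaled_scores = minmax_scale(scaled_scores)

    # Weighted sum of the two scaled metrics; lower combined score ranks first
    combined_scores = w_rmse * scaled_scores[:, 0] + w_gpvar * scaled_scores[:, 1]
    comb_sort_indices = np.argsort(combined_scores)
    return CombinedScore(combined_scores,
                         scaled_scores,
                         comb_sort_indices,
                         rmse_sort_indices,
                         gpvar_sort_indices)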
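
To see what the scaling and weighting do, here is a small worked example that reproduces the arithmetic with plain NumPy. It is only a sketch: the three score pairs are made up for illustration, and the min-max step is written out by hand on the assumption that no column is constant (it stands in for `minmax_scale` with its default range).

```python
import numpy as np

# Hypothetical mean scores for three grid points.
# Columns: (mean RMSE, mean GP variance). Values are invented for illustration.
mean_scores = np.array([
    [2.0, 0.5],
    [1.0, 1.5],
    [3.0, 1.0],
])

# Column-wise min-max scaling to [0, 1].
# Assumes every column has max > min; otherwise this divides by zero.
col_min = mean_scores.min(axis=0)
col_max = mean_scores.max(axis=0)
scaled = (mean_scores - col_min) / (col_max - col_min)

# Same default weights as combine_rmse_gpvar.
w_rmse, w_gpvar = 0.8, 0.2
combined = w_rmse * scaled[:, 0] + w_gpvar * scaled[:, 1]

# Ascending sort: the grid point with the lowest combined score comes first.
ranking = np.argsort(combined)

print(scaled)
print(combined)
print(ranking)  # [1 0 2]
```

With these numbers the second grid point ranks first: it has the lowest RMSE, and at a weight of 0.2 its high GP variance is not enough to outweigh that.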