kernel_selection2.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:DataMining 作者: lidalei 项目源码 文件源码
def hot(X, y, C_range=None, gamma_range=None):
    """Grid-search an RBF-kernel SVC over (C, gamma) and plot mean CV AUC as a heat map.

    For every (C, gamma) pair the classifier is evaluated with 10-fold
    ``KFold`` cross-validation, scoring each held-out fold with scikit-learn's
    ``"roc_auc"`` scorer; the fold scores are averaged.  The resulting
    ``len(C_range) x len(gamma_range)`` matrix is printed and displayed with
    ``plt.imshow`` (rows = C, columns = gamma).

    Parameters
    ----------
    X : array-like
        Feature matrix. Must support integer-array indexing (``X[train]``),
        e.g. a NumPy array.
    y : array-like
        Labels, indexable the same way as ``X``. ROC AUC needs both classes
        present in every test fold.
    C_range : sequence of float, optional
        Values of ``C`` to try. Defaults to ``np.logspace(-15, 15, 31, base=2.0)``
        (2**-15 ... 2**15), the grid this function originally hard-coded.
    gamma_range : sequence of float, optional
        Values of ``gamma`` to try. Same default as ``C_range``.

    Returns
    -------
    None
        The score matrix is printed and plotted, not returned.
    """
    if C_range is None:
        C_range = np.logspace(-15, 15, 31, base=2.0)
    if gamma_range is None:
        gamma_range = np.logspace(-15, 15, 31, base=2.0)

    roc_auc_scorer = get_scorer("roc_auc")

    # The fold indices depend only on len(X), so build them once rather than
    # once per (C, gamma) pair.
    # NOTE(review): in scikit-learn's KFold, random_state only takes effect when
    # shuffle=True, so these are contiguous, unshuffled folds. If X/y are ordered
    # by class, a test fold can contain a single class and roc_auc will raise.
    # Consider shuffling or a stratified splitter.
    folds = list(KFold(n=len(X), n_folds=10, random_state=42))

    scores = np.empty((len(C_range), len(gamma_range)))
    for i, C in enumerate(C_range):
        for j, gamma in enumerate(gamma_range):
            fold_aucs = []
            for train, test in folds:
                # probability=True is intentionally omitted: the "roc_auc" scorer
                # ranks with SVC.decision_function, and enabling probability would
                # add an internal cross-validation to every fit without changing
                # the decision values.
                rbf_clf = svm.SVC(C=C, kernel='rbf', gamma=gamma)
                rbf_clf.fit(X[train], y[train])
                fold_aucs.append(roc_auc_scorer(rbf_clf, X[test], y[test]))
            scores[i, j] = np.mean(fold_aucs)

    # Parenthesised form prints the same on Python 2 and is valid on Python 3.
    print(scores)

    plt.figure(figsize=(15, 12))
    plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
    # MidpointNormalize is defined outside this view. Presumably it centres the
    # colour map at midpoint=0.92 so differences between high-AUC cells stay
    # visible. TODO confirm against its definition.
    plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,
               norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
    plt.xlabel('gamma')
    plt.ylabel('C')
    plt.colorbar()
    # Compact tick labels: raw floats such as 3.0517578125e-05 crowd the axes.
    plt.xticks(np.arange(len(gamma_range)),
               ['%.3g' % g for g in gamma_range], rotation=90)
    plt.yticks(np.arange(len(C_range)), ['%.3g' % c for c in C_range])
    plt.title('AUC')
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号