LinearTrainer.py source file

python

Project: Steal-ML  Author: ftramer
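    def grid_search(self):
        # Candidate C values: 2^-5, 2^-4, ..., 2^15 (21 points).
        C_range = np.logspace(-5, 15, 21, base=2)
        param_grid = dict(C=C_range)
        # Legacy scikit-learn API (pre-0.18): labels and n_iter go to the constructor.
        # Five stratified shuffles, each holding out 20% of the data.
        cv = StratifiedShuffleSplit(self.y_ex, n_iter=5, test_size=0.2, random_state=42)
        grid = GridSearchCV(LinearSVC(dual=False, max_iter=10000), param_grid=param_grid,
                            cv=cv,
                            n_jobs=1, verbose=0)

        logger.info('start grid search for Linear')
        grid.fit(self.X_ex, self.y_ex)
        logger.info('end grid search for Linear')

        # Mean validation score per C value (legacy grid_scores_ attribute); not used below.
        scores = [x[1] for x in grid.grid_scores_]

        # final train
        # Despite its name, rbf_svc2 is the best LinearSVC found by the search.
        rbf_svc2 = grid.best_estimator_

        pred_train = rbf_svc2.predict(self.X_ex)
        pred_val = rbf_svc2.predict(self.val_x)
        pred_test = rbf_svc2.predict(self.test_x)

        # Report accuracy on the training, validation, and test sets.
        r = Result(self.name + ' (X)', 'Linear', len(self.X_ex),
                   sm.accuracy_score(self.y_ex, pred_train),
                   sm.accuracy_score(self.val_y, pred_val),
                   sm.accuracy_score(self.test_y, pred_test))
        return r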
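
The method above relies on scikit-learn interfaces that were removed in later releases: the `StratifiedShuffleSplit(y, n_iter=...)` constructor and the `grid_scores_` attribute. The following is a minimal sketch of the same search written against the `sklearn.model_selection` API, not the project's own code. The function name `linear_grid_search` and the synthetic data from `make_classification` are assumptions made for illustration. The original works on `self.X_ex`, `self.y_ex`, and separate validation and test sets, and wraps the accuracies in a `Result` object that is not reproduced here.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit, train_test_split
from sklearn.svm import LinearSVC


def linear_grid_search(X_train, y_train, random_state=42):
    """Search C over 2^-5 ... 2^15 for a LinearSVC and return the fitted search."""
    # 21 values, one per integer exponent from -5 to 15.
    C_range = np.logspace(-5, 15, 21, base=2)
    # Five stratified 80/20 shuffles; labels are passed to fit(), not the constructor.
    cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=random_state)
    grid = GridSearchCV(
        LinearSVC(dual=False, max_iter=10000),
        param_grid={"C": C_range},
        cv=cv,
        n_jobs=1,
        verbose=0,
    )
    grid.fit(X_train, y_train)
    return grid


# Synthetic stand-in data (hypothetical; the original uses self.X_ex / self.y_ex).
X, y = make_classification(n_samples=300, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

grid = linear_grid_search(X_train, y_train)
# cv_results_["mean_test_score"] plays the role of the old grid_scores_ means.
scores = grid.cv_results_["mean_test_score"]
best = grid.best_estimator_
train_acc = accuracy_score(y_train, best.predict(X_train))
test_acc = accuracy_score(y_test, best.predict(X_test))
print(grid.best_params_, train_acc, test_acc)
```

As in the original, `GridSearchCV` refits the best configuration on the full training set by default, so `best_estimator_` can be used for prediction directly.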