RBFSolver.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Steal-ML 作者: ftramer 项目源码 文件源码
def do(self, n_pts):
        """Grid-search the RBF gamma on collected points and score the best model.

        Collects ``n_pts`` points via ``self.collect_pts``, fits a pipeline of
        an ``RBFSampler`` feature map followed by a ``HyperSolver``, and picks
        the ``gamma`` with the best cross-validated score. The validation
        curve is saved to ``<self.name>-SLViF-grid-npts=<n_pts>.pdf``.

        Args:
            n_pts: Number of points to collect; also used as the number of
                random features (``n_components``) of the RBF map.

        Returns:
            A ``(best_cv_score, test_accuracy)`` tuple, where the second item
            is the accuracy of the best estimator's predictions on ``self.Xt``
            measured against ``self.Yt``.
        """
        X, y = self.collect_pts(n_pts)

        print('done collecting points')

        rbf_map = RBFSampler(n_components=n_pts, random_state=1)
        solver = HyperSolver(p=self.POS, n=self.NEG)
        rbf_solver = pipeline.Pipeline([("mapper", rbf_map),
                                        ("solver", solver)])

        # 22 candidate gammas, log-spaced from 2**-15 to 2**6.
        gamma_range = np.logspace(-15, 6, 22, base=2)
        param_grid = dict(mapper__gamma=gamma_range)
        # Fixed random_state keeps the 5 stratified 80/20 splits reproducible.
        cv = StratifiedShuffleSplit(y, n_iter=5, test_size=0.2, random_state=1)
        grid = GridSearchCV(rbf_solver, param_grid=param_grid, cv=cv, n_jobs=8)
        grid.fit(X, y)

        # Index 1 of each grid_scores_ entry is taken as the score for that
        # gamma; the reshape doubles as a check that there is one per gamma.
        scores = [x[1] for x in grid.grid_scores_]
        scores = np.array(scores).reshape(len(gamma_range))
        fig = plt.figure(figsize=(8, 6))
        plt.plot(gamma_range, scores)

        plt.xlabel('gamma')
        plt.ylabel('score')
        plt.title('Validation accuracy (RTiX, %s)' % os.path.basename(self.name))
        plt.savefig(self.name + '-SLViF-grid-npts=%d.pdf' % n_pts)
        # Release the figure so repeated calls do not pile up open figures.
        plt.close(fig)

        # final train
        g = grid.best_params_['mapper__gamma']
        print('best parameters are g=%f' % g)
        rbf_svc2 = grid.best_estimator_
        y_pred = rbf_svc2.predict(self.Xt)
        # Compute the held-out accuracy once and reuse it for print and return.
        test_accuracy = sm.accuracy_score(self.Yt, y_pred)
        print('SCORE: %f' % test_accuracy)
        return grid.best_score_, test_accuracy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号