saver.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:rnnlab 作者: phueb 项目源码 文件源码
def calc_ba_data(self, probe_simmat, multi_probe_list):
    """Choose a threshold by Bayesian optimization, then score probes with it.

    A search window for the threshold ``thr`` is derived from the mean of
    ``probe_simmat``. ``BayesianOptimization`` maximizes
    ``self.calc_probe_ba_list`` over that window, and the best ``thr`` found is
    used for a final call to ``self.calc_probe_ba_list``.

    Args:
        probe_simmat: similarity matrix (anything ``np.mean`` accepts);
            presumably probe-by-probe similarities -- TODO confirm.
        multi_probe_list: one probe label per entry of the list returned by
            ``calc_probe_ba_list``; used as the grouping key when averaging
            (the two must have equal length for the DataFrame below).

    Returns:
        Tuple ``(probe_ba_list, avg_probe_ba_list)``:

        - ``probe_ba_list``: per-entry values from ``calc_probe_ba_list`` at the
          best threshold, multiplied by 100, as a plain list ("ba" is
          presumably balanced accuracy -- confirm against calc_probe_ba_list).
        - ``avg_probe_ba_list``: mean of those values for each unique probe,
          ordered by sorted probe label (pandas ``groupby`` default), NOT by
          first appearance in ``multi_probe_list``.
    """
    # make thr range
    # float() instead of np.asscalar(): asscalar was deprecated in NumPy 1.16
    # and removed in 1.23, where the old call raises AttributeError.
    probe_simmat_mean = float(np.mean(probe_simmat))
    # Window is [r - 0.1, r + 0.1] around the mean rounded to 2 d.p. (r), with r
    # capped at 0.9 and the lower bound floored at 0.0. For 0 <= mean <= 1 the
    # raw mean (explored below) therefore falls inside [thr1, thr2].
    thr1 = max(0.0, round(min(0.9, round(probe_simmat_mean, 2)) - 0.1, 2))  # don't change
    thr2 = round(thr1 + 0.2, 2)
    # use bayes optimization to find best_thr
    if SaverConfigs.PRINT_BAYES_OPT:
        print('Finding best thresholds between {} and {} using bayesian-optimization...'.format(thr1, thr2))
    gp_params = {"alpha": 1e-5, "n_restarts_optimizer": 2}
    # The optimizer only supplies ``thr``; the remaining positional args are
    # bound here. The third positional arg is True during the search and
    # False in the final call below -- its meaning is defined in
    # calc_probe_ba_list, not visible here.
    func_to_be_opt = partial(self.calc_probe_ba_list, probe_simmat, multi_probe_list, True)
    # NOTE(review): explore() and res['max']['max_params'] appear to be the
    # pre-1.0 bayes_opt API; confirm the pinned bayes_opt version supports them.
    bo = BayesianOptimization(func_to_be_opt, {'thr': (thr1, thr2)}, verbose=SaverConfigs.PRINT_BAYES_OPT)
    # seed the search with the raw mean as a known evaluation point
    bo.explore({'thr': [probe_simmat_mean]})
    bo.maximize(init_points=2, n_iter=SaverConfigs.NUM_BAYES_STEPS,
                acq="poi", xi=0.001, **gp_params)  # smaller xi: exploitation
    best_thr = bo.res['max']['max_params']['thr']
    # calc probe_ba_list with best_thr, scaled to percentages
    probe_ba_list = self.calc_probe_ba_list(probe_simmat, multi_probe_list, False, best_thr)
    probe_ba_list = np.multiply(probe_ba_list, 100).tolist()
    # make avg_probe_ba_list: mean per unique probe, in sorted-label order
    avg_probe_ba_list = (
        pd.DataFrame(data={'probe': multi_probe_list, 'probe_ba': probe_ba_list})
        .groupby('probe')['probe_ba']
        .mean()
        .values
        .tolist()
    )
    return probe_ba_list, avg_probe_ba_list
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号