policy_opt_caffe.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:gps_superball_public 作者: young-geng 项目源码 文件源码
def __init__(self, hyperparams, dO, dU):
    """Set up a Caffe-backed policy optimizer and its policy object.

    Merges ``hyperparams`` over a deep copy of the ``POLICY_OPT_CAFFE``
    defaults and initializes the ``PolicyOpt`` base with the merged config.
    It then selects the GPU or CPU Caffe backend and builds the solver via
    ``self.init_solver()``. Initial weights and normalization statistics are
    loaded if configured. Finally it constructs ``self.policy`` as a
    ``CaffePolicy``.

    Hyperparameter keys read (after merging): ``batch_size``, ``use_gpu``,
    ``gpu_id`` (only when ``use_gpu`` is truthy), ``init_var``, and the
    optional ``init_net`` and ``init_normalization``.

    Args:
        hyperparams: dict of overrides applied on top of ``POLICY_OPT_CAFFE``.
        dO: forwarded to ``PolicyOpt.__init__`` (presumably the observation
            dimension -- TODO confirm).
        dU: forwarded to ``PolicyOpt.__init__``; also the length of
            ``self.var`` (presumably the action dimension).
    """
    # Deep copy so updating the config never mutates the shared defaults.
    config = copy.deepcopy(POLICY_OPT_CAFFE)
    config.update(hyperparams)

    PolicyOpt.__init__(self, config, dO, dU)

    self.batch_size = self._hyperparams['batch_size']

    if self._hyperparams['use_gpu']:
        caffe.set_device(self._hyperparams['gpu_id'])
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()

    self.init_solver()
    # Optionally warm-start the training net from a caffemodel weights file.
    if 'init_net' in self._hyperparams:
        self.solver.net.copy_from(self._hyperparams['init_net'])

    self.caffe_iter = 0
    # Length-dU vector filled with init_var (presumably the policy's action
    # variance); it is handed to CaffePolicy below.
    self.var = self._hyperparams['init_var'] * np.ones(dU)

    # NOTE(review): the roles of test_nets[0] vs test_nets[1] are defined by
    # the solver/net configuration, which is not visible here -- confirm.
    self.policy = CaffePolicy(self.solver.test_nets[0],
                              self.solver.test_nets[1],
                              self.var)

    # Normalization stays unset unless a stats file is configured.
    self.policy.bias = None
    self.policy.scale = None
    if 'init_normalization' in self._hyperparams:
        # Pickle is a binary format. Open in 'rb' so pickle.load works on
        # Python 3 (it rejects text-mode files) and the data is not corrupted
        # by newline translation on Python 2 under Windows.
        # SECURITY: pickle.load can execute arbitrary code while loading.
        # Only point 'init_normalization' at trusted files.
        with open(self._hyperparams['init_normalization'], 'rb') as fin:
            normalization_data = pickle.load(fin)
        self.policy.bias = normalization_data['bias']
        self.policy.scale = normalization_data['scale']
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号