policy_opt_caffe.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:gps 作者: cbfinn 项目源码 文件源码
def __init__(self, hyperparams, dO, dU):
        config = copy.deepcopy(POLICY_OPT_CAFFE)
        config.update(hyperparams)

        PolicyOpt.__init__(self, config, dO, dU)

        self.batch_size = self._hyperparams['batch_size']

        if self._hyperparams['use_gpu']:
            caffe.set_device(self._hyperparams['gpu_id'])
            caffe.set_mode_gpu()
        else:
            caffe.set_mode_cpu()

        self.init_solver()
        self.caffe_iter = 0
        self.var = self._hyperparams['init_var'] * np.ones(dU)

        self.policy = CaffePolicy(self.solver.test_nets[0],
                                  self.solver.test_nets[1],
                                  self.var)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号