def __init__(self, hyperparams, dO, dU):
    """Set up the Caffe-backed policy optimizer.

    The supplied hyperparameters are layered over a deep copy of
    POLICY_OPT_CAFFE, so POLICY_OPT_CAFFE itself is not modified by the
    update. The merged settings are handed to PolicyOpt.__init__, the
    Caffe compute mode is selected, the solver is initialized, and a
    CaffePolicy is created.

    Args:
        hyperparams: Mapping of overrides applied on top of the defaults.
        dO: Forwarded unchanged to PolicyOpt.__init__ (presumably the
            observation dimension -- confirm against PolicyOpt).
        dU: Forwarded to PolicyOpt.__init__ and passed to np.ones() to
            shape the initial variance array (presumably the action
            dimension -- confirm against PolicyOpt).
    """
    merged = copy.deepcopy(POLICY_OPT_CAFFE)
    merged.update(hyperparams)
    PolicyOpt.__init__(self, merged, dO, dU)

    # NOTE(review): self._hyperparams is presumably populated by
    # PolicyOpt.__init__ from the merged settings -- confirm.
    self.batch_size = self._hyperparams['batch_size']

    # Select the Caffe compute mode before the solver is initialized.
    if not self._hyperparams['use_gpu']:
        caffe.set_mode_cpu()
    else:
        caffe.set_device(self._hyperparams['gpu_id'])
        caffe.set_mode_gpu()

    self.init_solver()
    self.caffe_iter = 0

    # Initial variance: np.ones(dU) scaled by the 'init_var' setting.
    self.var = self._hyperparams['init_var'] * np.ones(dU)

    # The policy receives the solver's first two test nets plus the variance.
    test_nets = self.solver.test_nets
    self.policy = CaffePolicy(test_nets[0], test_nets[1], self.var)
评论列表
文章目录