def __init__(self, hyperparams, dO, dU):
    """Set up the Caffe-backed policy optimizer and its initial policy.

    Args:
        hyperparams: Dict of overrides merged on top of a deep copy of
            ``POLICY_OPT_CAFFE`` before being handed to
            ``PolicyOpt.__init__``.
        dO: Passed through to ``PolicyOpt.__init__``; presumably the
            observation dimensionality -- confirm against PolicyOpt.
        dU: Length of the initial variance vector
            (``init_var * np.ones(dU)``); presumably the action
            dimensionality.

    Hyperparameter keys read here: ``batch_size``, ``use_gpu``,
    ``gpu_id`` (only when ``use_gpu`` is truthy), ``init_var``, and the
    optional ``init_net`` (caffemodel file whose parameters are copied
    into ``self.solver.net``) and ``init_normalization`` (path to a
    pickle holding a mapping with ``'bias'`` and ``'scale'`` entries
    that are assigned to the policy).
    """
    # Deep-copy the defaults so that update() never mutates the shared
    # POLICY_OPT_CAFFE dict.
    config = copy.deepcopy(POLICY_OPT_CAFFE)
    config.update(hyperparams)
    PolicyOpt.__init__(self, config, dO, dU)
    # NOTE(review): everything below reads self._hyperparams, which is
    # presumably set to `config` by PolicyOpt.__init__ -- confirm.

    self.batch_size = self._hyperparams['batch_size']

    # Select Caffe's compute device before init_solver() runs.
    if self._hyperparams['use_gpu']:
        caffe.set_device(self._hyperparams['gpu_id'])
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()

    # init_solver() is expected to create self.solver, which is used below.
    self.init_solver()

    # Load parameters from caffemodel file, if one was configured.
    if 'init_net' in self._hyperparams:
        self.solver.net.copy_from(self._hyperparams['init_net'])

    self.caffe_iter = 0
    self.var = self._hyperparams['init_var'] * np.ones(dU)
    self.policy = CaffePolicy(self.solver.test_nets[0],
                              self.solver.test_nets[1],
                              self.var)

    # Normalization stays unset (None) unless a saved one is provided.
    self.policy.bias = None
    self.policy.scale = None
    if 'init_normalization' in self._hyperparams:
        # Pickle data must be read in binary mode: text mode raises
        # TypeError under Python 3 and can corrupt data on Windows under
        # Python 2.
        # SECURITY: pickle.load can execute arbitrary code; only point
        # 'init_normalization' at trusted files.
        with open(self._hyperparams['init_normalization'], 'rb') as fin:
            normalization_data = pickle.load(fin)
        self.policy.bias = normalization_data['bias']
        self.policy.scale = normalization_data['scale']
policy_opt_caffe.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录