def _load_solver(self, solver_params, model_params):
    """Load solver: write model definition files, then build a Caffe SGD solver.

    When ``model_params.max_rounds > 1`` the solver output directory is
    redirected to a per-round subdirectory ``<base>/round_<N>`` with
    ``N = self._cur_round``. The new path is stored back into
    ``solver_params`` via ``set_path``.

    Args:
        solver_params: Solver configuration object. Its ``path`` attribute is
            read, ``set_path`` is called in multi-round mode, and
            ``to_proto()`` supplies the solver definition.
        model_params: Provides ``max_rounds`` and ``model``. When ``model``
            is not None, it is asked to write its proto files into the
            solver directory.

    Returns:
        The ``caffe.SGDSolver`` built from ``solver_params.to_proto()``.
    """
    solver_path = solver_params.path
    if model_params.max_rounds > 1:
        # set_path() below mutates solver_params, so on later calls the path
        # may already end in a previous round's directory (``round_<k>``).
        # Strip such a trailing component explicitly instead of inferring it
        # from the round number. Otherwise, resuming at round > 0 with a fresh
        # base path would climb to the parent directory, and a repeated call
        # in round 0 would nest ``round_0/round_0``.
        round_prefix = 'round_'
        normalized = osp.normpath(solver_path)
        last_component = osp.basename(normalized)
        if (last_component.startswith(round_prefix)
                and last_component[len(round_prefix):].isdigit()):
            solver_path = osp.dirname(normalized)
        solver_path = osp.join(solver_path, 'round_{}'.format(self._cur_round))
        solver_params.set_path(solver_path)
    model = model_params.model
    if model is not None:
        # Write both proto variants (deploy=False and deploy=True) into the
        # solver directory.
        model.to_proto(solver_path, deploy=False)
        model.to_proto(solver_path, deploy=True)
        # Single-argument print(): same output on Python 2, also valid on Python 3.
        print('Model files saved at {}'.format(solver_path))
    return caffe.SGDSolver(solver_params.to_proto())
评论列表
文章目录