def __init__(self, solver_prototxt=None, path=None, base_lr=0.01, lr_policy="step",
             gamma=0.1, stepsize=20000, momentum=0.9, weight_decay=0.0005,
             regularization_type="L2", clip_gradients=None):
    """Build the Caffe ``SolverParameter`` message stored in ``self._solver``.

    Exactly one source is used, and ``solver_prototxt`` takes precedence:

    * ``solver_prototxt`` given: that text-format file is parsed into the
      message. ``path`` and all hyperparameter arguments are ignored.
    * otherwise ``path`` is used: the message is filled from the keyword
      arguments, with ``net`` pointing at ``<path>/train_val.prototxt``.
      Caffe's own snapshotting, display output and validation are switched
      off. ``self._solver_prototxt`` is set to ``<path>/solver.prototxt``.
      This constructor does not write that file itself (presumably done
      elsewhere; confirm).

    Args:
        solver_prototxt: Path to an existing solver definition in protobuf
            text format.
        path: Directory containing ``train_val.prototxt``. Used only when
            ``solver_prototxt`` is None.
        base_lr, lr_policy, gamma, stepsize, momentum, weight_decay,
        regularization_type: Copied verbatim into the same-named
            ``SolverParameter`` fields (``path`` mode only).
        clip_gradients: If not None, copied into ``clip_gradients``
            (``path`` mode only).

    Raises:
        ValueError: If both ``solver_prototxt`` and ``path`` are None.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let construction "succeed" with neither
    # source configured and ``self._solver_prototxt`` never assigned.
    if path is None and solver_prototxt is None:
        raise ValueError('Need to specify either path or solver_prototxt.')
    self._solver = caffe_pb2.SolverParameter()
    if solver_prototxt is not None:
        self._solver_prototxt = solver_prototxt
        with open(solver_prototxt, 'rt') as f:
            # Parse the text-format protobuf into the freshly created message.
            pb2.text_format.Merge(f.read(), self._solver)
    else:
        # NOTE(review): SOURCE had lost its indentation. Everything below is
        # reconstructed as belonging to the ``path`` branch only. Confirm
        # against the original file.
        self._solver_prototxt = osp.join(path, 'solver.prototxt')
        # update proto object from the constructor arguments
        self._solver.net = osp.join(path, 'train_val.prototxt')
        self._solver.base_lr = base_lr
        self._solver.lr_policy = lr_policy
        self._solver.gamma = gamma
        self._solver.stepsize = stepsize
        self._solver.momentum = momentum
        self._solver.weight_decay = weight_decay
        self._solver.regularization_type = regularization_type
        # caffe solver snapshotting is disabled
        self._solver.snapshot = 0
        # shut down caffe display
        self._solver.display = 0
        # shut down caffe validation: a single test_iter entry of 0 means
        # each test pass presumably runs no iterations
        self._solver.test_iter.append(0)
        self._solver.test_interval = 1000
        if clip_gradients is not None:
            self._solver.clip_gradients = clip_gradients
评论列表
文章目录