def generate_solver_proto(solver_fn, model_fn, trainOpts):
    """Build a Caffe ``SolverParameter`` and write it as text to ``solver_fn``.

    Args:
        solver_fn: Path of the solver file to write (opened in ``'w'`` mode).
        model_fn: Path of the net definition. Stored as ``solver.net``; the
            same path with its extension removed becomes the snapshot prefix.
        trainOpts: Options object. The attributes read here are
            ``num_lr_decays``, ``lr_decay_factor`` (only when
            ``num_lr_decays > 0``), ``iters``, ``init_lr`` and ``paramReg``.
    """
    # Function-local import: caffe is only loaded when this function runs.
    from caffe.proto import caffe_pb2

    params = caffe_pb2.SolverParameter()
    params.net = model_fn

    # Learning-rate schedule: 'step' when decays are requested, else 'fixed'.
    step_decay = trainOpts.num_lr_decays > 0
    params.lr_policy = 'step' if step_decay else 'fixed'
    if step_decay:
        params.gamma = trainOpts.lr_decay_factor
        # Divide the total iterations into num_lr_decays + 1 equal segments.
        segments = trainOpts.num_lr_decays + 1
        params.stepsize = int(trainOpts.iters / segments)

    params.base_lr = trainOpts.init_lr
    params.max_iter = trainOpts.iters
    params.display = 20
    params.momentum = 0.9
    params.weight_decay = trainOpts.paramReg

    # One test state per stage, in this order, each run for 20 iterations.
    for stage_name in ('TestRecognition', 'TestZeroShot'):
        params.test_state.add().stage.append(stage_name)
    params.test_iter.extend([20, 20])
    params.test_interval = 100

    params.snapshot = 5000
    prefix, _ext = os.path.splitext(model_fn)
    params.snapshot_prefix = prefix

    # str() on the message is what gets written to the solver file.
    with open(solver_fn, 'w') as out_file:
        out_file.write(str(params))
# 评论列表 (comment list)
# 文章目录 (article table of contents)