def solver_file(model_root, model_name):
    """Write a Caffe ``solver.prototxt`` for the model located at *model_root*.

    Builds a ``caffe_pb2.SolverParameter`` with fixed hyperparameters and
    writes it to ``model_root + 'solver.prototxt'`` in protobuf text format
    (via ``str(message)``).

    Paths are formed by plain string concatenation, so *model_root* should
    normally end with a path separator (e.g. ``'models/lenet/'``).

    Args:
        model_root: Prefix for ``train.prototxt``, ``test.prototxt`` and the
            output ``solver.prototxt``; also prepended to snapshot names.
        model_name: Inserted between *model_root* and the snapshot suffix to
            form the snapshot file prefix.

    Returns:
        The path of the solver file that was written.
    """
    s = caffe_pb2.SolverParameter()

    # Network definitions used for the training and test phases.
    s.train_net = model_root + 'train.prototxt'
    s.test_net.append(model_root + 'test.prototxt')

    # Run a test phase every 500 training iterations; each test phase runs
    # 100 forward iterations of the test net.
    s.test_interval = 500
    s.test_iter.append(100)

    # Total number of training iterations.
    s.max_iter = 10000

    # Base learning rate, momentum, and weight decay (regularization).
    s.base_lr = 0.01
    s.momentum = 0.9
    s.weight_decay = 5e-4

    # Learning-rate schedule. Available policies:
    #   fixed:     base_lr stays constant
    #   step:      base_lr * gamma ^ floor(iter / stepsize)
    #   exp:       base_lr * gamma ^ iter
    #   inv:       base_lr * (1 + gamma * iter) ^ (-power)
    #   multistep: like step, but decays at each iteration listed in stepvalue
    #   poly:      base_lr * (1 - iter / max_iter) ^ power
    #   sigmoid:   base_lr * 1 / (1 + exp(-gamma * (iter - stepsize)))
    # 'inv' uses both gamma and power, set below.
    s.lr_policy = 'inv'
    s.gamma = 0.0001
    s.power = 0.75

    s.display = 100  # report training progress every 100 iterations
    s.snapshot = 5000  # save a snapshot every 5000 iterations
    # NOTE(review): 'shapshot' looks like a typo for 'snapshot', but it fixes
    # the on-disk snapshot file names other scripts may rely on, so it is
    # intentionally left unchanged.
    s.snapshot_prefix = model_root + model_name + 'shapshot'

    # Optimizer type; other options: AdaDelta, AdaGrad, Adam, Nesterov, RMSProp.
    s.type = 'SGD'
    # Device to run on (the alternative is caffe_pb2.SolverParameter.CPU).
    s.solver_mode = caffe_pb2.SolverParameter.GPU

    # Use a distinct local name: reusing ``solver_file`` here would shadow
    # this very function inside its own body.
    solver_path = model_root + 'solver.prototxt'
    with open(solver_path, 'w') as f:
        f.write(str(s))
    return solver_path
评论列表
文章目录