def __init__(self, solver_prototxt, db, output_dir, do_flip,
             snapshot_path=None):
    """Initialize the SolverWrapper.

    Builds a ``caffe.SGDSolver`` from ``solver_prototxt``, parses the same
    file into a ``caffe_pb2.SolverParameter`` message, optionally copies
    weights from ``snapshot_path`` into the solver's net, and hands ``db``
    to the net's first layer.

    Args:
        solver_prototxt: Path to the Caffe solver definition (text protobuf).
        db: Data source passed to ``self._solver.net.layers[0].set_db``;
            presumably the training data for a custom Python data layer —
            TODO confirm against the net definition.
        output_dir: Directory stored on ``self._output_dir``.
        do_flip: Flag forwarded to ``layers[0].set_db`` together with ``db``.
        snapshot_path: Optional path to a weights file whose parameters are
            copied into the solver's net via ``copy_from``.
    """
    import os.path  # local import: the file-level import block is not visible here

    self._output_dir = output_dir
    self._solver = caffe.SGDSolver(solver_prototxt)

    # Parse the solver prototxt a second time so its fields (e.g.
    # snapshot_prefix) are readable from Python.
    self._solver_param = caffe_pb2.SolverParameter()
    with open(solver_prototxt, 'rt') as f:
        pb2.text_format.Merge(f.read(), self._solver_param)

    infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
             if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
    self._snapshot_prefix = (self._solver_param.snapshot_prefix + infix
                             + '_iter_')

    if snapshot_path is not None:
        # Format first, then print: the original `print (...).format(...)`
        # only worked as a Python 2 print statement and fails on Python 3.
        print('Loading snapshot weights from {:s}'.format(snapshot_path))
        self._solver.net.copy_from(snapshot_path)
        # Compare file names only: snapshot_prefix may include a directory
        # component, which a bare file name could never start with.
        snapshot_name = os.path.basename(snapshot_path)
        if snapshot_name.startswith(os.path.basename(self._snapshot_prefix)):
            print('Warning! Existing snapshots may be overridden by new snapshots!')

    self._db = db
    self._solver.net.layers[0].set_db(self._db, do_flip)
# Comment list
# Table of contents