def __init__(self, solver_prototxt, roidb, output_dir,
             previous_state=None, pretrained_model=None):
    """Initialize the SolverWrapper.

    Steps performed:
      1. Optionally add bounding-box regression targets to ``roidb``.
      2. Redirect the solver's ``snapshot_prefix`` into ``output_dir``.
      3. Build a Caffe SGD solver from the rewritten solver definition.
      4. Optionally restore a previous solver state or load pretrained
         weights.
      5. Hand ``roidb`` to the network's first layer.

    Args:
        solver_prototxt: Path to the solver definition (text protobuf).
        roidb: Training roidb. It is passed to
            ``rdl_roidb.add_bbox_regression_targets`` when
            ``cfg.TRAIN.BBOX_REG`` is set. It is always passed to
            ``self.solver.net.layers[0].set_roidb``.
        output_dir: Directory that the solver's snapshot prefix is joined
            onto.
        previous_state: Optional solver-state file to restore. Takes
            precedence over ``pretrained_model``.
        pretrained_model: Optional weights file copied into the net when
            ``previous_state`` is not given.

    Sets ``self.output_dir``, ``self.solver_param`` and ``self.solver``.
    When ``cfg.TRAIN.BBOX_REG`` is set, also sets ``self.bbox_means`` and
    ``self.bbox_stds``.
    """
    self.output_dir = output_dir

    if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
            cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
        # RPN can only use precomputed normalization because there are no
        # fixed statistics to compute a priori
        assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

    if cfg.TRAIN.BBOX_REG:
        print('Computing bounding-box regression targets...')
        self.bbox_means, self.bbox_stds = \
            rdl_roidb.add_bbox_regression_targets(roidb)
        print('done')

    # Parse the solver definition so the snapshot prefix can be redirected
    # into output_dir, optionally suffixed with '_' + SNAPSHOT_INFIX.
    self.solver_param = caffe_pb2.SolverParameter()
    with open(solver_prototxt, 'rt') as f:
        pb2.text_format.Merge(f.read(), self.solver_param)
    infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
             if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
    self.solver_param.snapshot_prefix = os.path.join(
        self.output_dir, self.solver_param.snapshot_prefix + infix)
    print('Change the snapshot prefix to %s'
          % self.solver_param.snapshot_prefix)

    # caffe.SGDSolver is constructed from a file path, so serialize the
    # modified parameters to a temporary prototxt. Text mode matches the
    # str being written, on both Python 2 and 3.
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        f.write(str(self.solver_param))
        tmp_prototxt = f.name
    print('Create temporary solver prototxt at %s' % tmp_prototxt)
    try:
        self.solver = caffe.SGDSolver(tmp_prototxt)
    finally:
        # delete=False above means nobody else cleans this file up.
        # NOTE(review): assumes SGDSolver fully parses the file during
        # construction and never re-reads it — confirm for this Caffe fork.
        os.remove(tmp_prototxt)

    if previous_state is not None:
        print('Restoring solver state from {:s}'.format(previous_state))
        self.solver.restore(previous_state)
    elif pretrained_model is not None:
        print('Loading pretrained model weights from {:s}'.format(
            pretrained_model))
        self.solver.net.copy_from(pretrained_model)

    self.solver.net.layers[0].set_roidb(roidb)
评论列表
文章目录