def __init__(self, solver_prototxt, roidb, output_dir,
             nccl_uid, rank, bbox_means=None, bbox_stds=None,
             pretrained_model=None):
    """Initialize the SolverWrapper.

    Builds a Caffe SGD solver from ``solver_prototxt``, checks that the
    effective batch size matches the config, optionally loads pretrained
    weights, wires up NCCL synchronization, parses the solver definition
    into ``self.solver_param`` and hands ``roidb`` to the net's first layer.

    Args:
        solver_prototxt: Path to the solver definition. It is passed to
            ``caffe.SGDSolver`` and also parsed into ``self.solver_param``.
        roidb: Passed to ``set_roidb`` on ``self.solver.net.layers[0]``
            (presumably the data-input layer -- confirm).
        output_dir: Stored as ``self.output_dir``.
        nccl_uid: Identifier passed to ``caffe.NCCL`` (presumably shared by
            all participating processes -- confirm with the caller).
        rank: Stored as ``self.rank`` (presumably this worker's rank --
            confirm with the caller).
        bbox_means: Stored as ``self.bbox_means`` only when
            ``cfg.TRAIN.BBOX_REG`` is enabled.
        bbox_stds: Stored as ``self.bbox_stds`` only when
            ``cfg.TRAIN.BBOX_REG`` is enabled.
        pretrained_model: Optional path to weights that are copied into the
            net via ``copy_from``.

    Raises:
        AssertionError: (when assertions are enabled) if RPN training uses
            bbox-target normalization without precomputed statistics, if
            ``solver_count * IMS_PER_BATCH * iter_size`` differs from
            ``cfg.TRAIN.REAL_BATCH_SIZE``, or if the solver's
            ``layer_wise_reduce`` is disabled.
    """
    self.output_dir = output_dir
    self.rank = rank

    if cfg.TRAIN.BBOX_REG:
        self.bbox_means, self.bbox_stds = bbox_means, bbox_stds
    if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
            cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
        # RPN can only use precomputed normalization because there are no
        # fixed statistics to compute a priori
        assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

    self.solver = caffe.SGDSolver(solver_prototxt)

    # The configured REAL_BATCH_SIZE must equal the number of solvers times
    # the images per batch times the solver's iter_size; compute it once
    # and reuse it for both the check and the error message.
    effective_batch_size = (caffe.solver_count() * cfg.TRAIN.IMS_PER_BATCH *
                            self.solver.param.iter_size)
    assert effective_batch_size == cfg.TRAIN.REAL_BATCH_SIZE, \
        "{} vs {}".format(effective_batch_size, cfg.TRAIN.REAL_BATCH_SIZE)

    if pretrained_model is not None:
        # Format the message first, then print it: this works both as a
        # Python 2 print statement and as a Python 3 print() call.
        print(('Loading pretrained model '
               'weights from {:s}').format(pretrained_model))
        self.solver.net.copy_from(pretrained_model)

    nccl = caffe.NCCL(self.solver, nccl_uid)
    nccl.bcast()
    self.solver.add_callback(nccl)
    assert self.solver.param.layer_wise_reduce
    # The assert above disappears under ``python -O``; keep the explicit
    # guard so the after-backward hook is only installed when layer-wise
    # reduction is actually enabled.
    if self.solver.param.layer_wise_reduce:
        self.solver.net.after_backward(nccl)
    # Hold a Python reference to the NCCL object for the wrapper's lifetime.
    self.nccl = nccl

    # Keep a parsed copy of the solver definition on the instance.
    self.solver_param = caffe_pb2.SolverParameter()
    with open(solver_prototxt, 'rt') as f:
        pb2.text_format.Merge(f.read(), self.solver_param)

    self.solver.net.layers[0].set_roidb(roidb)
train.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录