train.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:person_search 作者: ShuangLI59 项目源码 文件源码
def __init__(self, solver_prototxt, roidb, output_dir,
                 previous_state=None, pretrained_model=None):
        """Initialize the SolverWrapper.

        Loads the solver definition and redirects its snapshots into
        ``output_dir``. It then builds a ``caffe.SGDSolver`` from a temporary
        copy of the modified definition and optionally restores a previous
        solver state or copies pretrained weights.

        Args:
            solver_prototxt: path to a text-format Caffe ``SolverParameter``
                file.
            roidb: training roidb. When ``cfg.TRAIN.BBOX_REG`` is set it is
                passed to ``rdl_roidb.add_bbox_regression_targets``. It is
                always handed to the first net layer via ``set_roidb``.
            output_dir: directory joined in front of the snapshot prefix.
            previous_state: optional solver-state file to restore from. It
                takes precedence over ``pretrained_model``.
            pretrained_model: optional weights file copied into the net when
                ``previous_state`` is None.

        Raises:
            AssertionError: if RPN training with bbox-target normalization is
                configured without precomputed normalization statistics.
        """
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori. Raised explicitly (rather
            # than with `assert`) so the check survives `python -O`; the
            # AssertionError type is kept for backward compatibility.
            if not cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
                raise AssertionError(
                    'cfg.TRAIN.BBOX_NORMALIZE_TARGETS with RPN requires '
                    'cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED')

        # self.bbox_means / self.bbox_stds are only set when BBOX_REG is on.
        if cfg.TRAIN.BBOX_REG:
            print('Computing bounding-box regression targets...')
            self.bbox_means, self.bbox_stds = \
                    rdl_roidb.add_bbox_regression_targets(roidb)
            print('done')

        # Change the snapshot file name: <output_dir>/<prefix>[_<infix>]
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)
        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        self.solver_param.snapshot_prefix = os.path.join(
            self.output_dir, self.solver_param.snapshot_prefix + infix)
        print('Change the snapshot prefix to %s'
              % self.solver_param.snapshot_prefix)

        # caffe.SGDSolver is constructed from a file path, so the modified
        # parameters are written to a temporary prototxt. Text mode ('w') is
        # required: the default 'w+b' rejects str on Python 3.
        # NOTE(review): the file is created with delete=False and is never
        # removed here; consider deleting it once the solver is built.
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
            f.write(str(self.solver_param))
            solver_prototxt = f.name
            print('Create temporary solver prototxt at %s' % solver_prototxt)

        self.solver = caffe.SGDSolver(solver_prototxt)
        if previous_state is not None:
            print('Restoring solver state from {:s}'.format(previous_state))
            self.solver.restore(previous_state)
        elif pretrained_model is not None:
            print('Loading pretrained model '
                  'weights from {:s}'.format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # The net's first layer must expose set_roidb (presumably the Python
        # data layer — confirm against the network definition).
        self.solver.net.layers[0].set_roidb(roidb)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号