SolverWrapper.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:MNC 作者: daijifeng001 项目源码 文件源码
def __init__(self, solver_prototxt, roidb, maskdb, output_dir, imdb,
                 pretrained_model=None):
        """Build the Caffe SGD solver and attach the training data to it.

        Validates the bbox-regression configuration, obtains the
        bbox-regression normalization statistics (when cfg.TRAIN.BBOX_REG is
        set), creates the solver, optionally loads pretrained weights, parses
        the solver parameters and hands the training data to the net's first
        layer.

        Args:
            solver_prototxt: Path to the solver prototxt. Used both to build
                ``caffe.SGDSolver`` and to fill ``self.solver_param``.
            roidb: Training ROI database. Passed to
                ``add_bbox_regression_targets`` and, outside CFM mode, to the
                first layer's ``set_roidb``.
            maskdb: Mask database. Passed to the first layer's ``set_maskdb``
                only when ``cfg.MNC_MODE`` is set and CFM mode is off.
            output_dir: Stored as ``self.output_dir``.
            imdb: Image database. In CFM mode its ``_roidb_path`` and
                ``num_classes`` are used to compute the MCG bbox statistics,
                and it is passed to the first layer's ``set_image_info``.
            pretrained_model: Optional path to weights copied into the net
                via ``copy_from``.

        Raises:
            ValueError: If RPN training with normalized bbox targets is
                configured without precomputed normalization, or if CFM mode
                is enabled without bbox regression. CFM mode consumes the
                bbox means/stds, which are only computed when bbox
                regression is on.
        """
        self.output_dir = output_dir

        # Validate the configuration up front, before any expensive setup.
        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                not cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori. Raised explicitly (not
            # asserted) so the check still runs under `python -O`.
            raise ValueError(
                'RPN training with normalized bbox targets requires '
                'cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED')
        if cfg.CFM_MODE and not cfg.TRAIN.BBOX_REG:
            # set_image_info below needs bbox_means/bbox_stds, which are only
            # computed when bbox regression is enabled. Without this check
            # the failure is an AttributeError after the solver is built.
            raise ValueError('cfg.CFM_MODE requires cfg.TRAIN.BBOX_REG')

        if cfg.TRAIN.BBOX_REG:
            if not cfg.CFM_MODE:
                print('Computing bounding-box regression targets...')
                self.bbox_means, self.bbox_stds = add_bbox_regression_targets(roidb)
                print('done')
            else:
                # Pre-defined MCG bbox means/stds are cached as .npy files.
                # Computing them presumably means reading the MCG boxes stored
                # on disk, so the cache is reused when both files exist.
                # NOTE(review): this block never writes the cache; confirm
                # that compute_mcg_mean_std saves it, otherwise the stats are
                # recomputed on every run.
                mean_cache = './data/cache/mcg_bbox_mean.npy'
                std_cache = './data/cache/mcg_bbox_std.npy'
                if os.path.exists(mean_cache) and os.path.exists(std_cache):
                    self.bbox_means = np.load(mean_cache)
                    self.bbox_stds = np.load(std_cache)
                else:
                    # The roidb path is only needed when recomputing.
                    self.bbox_means, self.bbox_stds = compute_mcg_mean_std(
                        imdb._roidb_path, imdb.num_classes)

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print('Loading pretrained model weights from {:s}'.format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Keep a parsed copy of the solver settings for later use.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        # The net's first layer receives the training data. It is presumably
        # a Python input layer exposing set_roidb / set_maskdb /
        # set_image_info; verify against the prototxt.
        input_layer = self.solver.net.layers[0]
        if cfg.CFM_MODE:
            input_layer.set_image_info(imdb, self.bbox_means, self.bbox_stds)
        else:
            input_layer.set_roidb(roidb)
            if cfg.MNC_MODE:
                input_layer.set_maskdb(maskdb)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号