train.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:WPAL-network 作者: kyu-sz 项目源码 文件源码
def __init__(self, solver_prototxt, db, output_dir, do_flip,
             snapshot_path=None):
    """Initialize the SolverWrapper.

    Builds a Caffe SGD solver from ``solver_prototxt``, parses the same
    prototxt into a ``SolverParameter`` to recover ``snapshot_prefix``,
    optionally copies weights from ``snapshot_path`` into the solver's net,
    and hands ``db`` to the net's first layer.

    Args:
        solver_prototxt: Path to the solver definition (text-format protobuf).
        db: Data source forwarded to ``net.layers[0].set_db``; its required
            interface is defined by that layer (not visible here).
        output_dir: Stored as ``self._output_dir`` (presumably where
            snapshots are written -- confirm with the rest of the class).
        do_flip: Flag forwarded to ``set_db`` together with ``db``
            (presumably toggles flip augmentation -- confirm in the layer).
        snapshot_path: Optional path to a weights file whose parameters are
            copied into the solver's net before training.
    """
    import os  # local import: the file-level imports are not visible here

    self._output_dir = output_dir
    self._solver = caffe.SGDSolver(solver_prototxt)

    # Parse the solver prototxt into a SolverParameter message so fields such
    # as snapshot_prefix can be read from Python.
    self._solver_param = caffe_pb2.SolverParameter()
    with open(solver_prototxt, 'rt') as f:
        pb2.text_format.Merge(f.read(), self._solver_param)

    infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
             if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
    self._snapshot_prefix = (self._solver_param.snapshot_prefix + infix +
                             '_iter_')

    if snapshot_path is not None:
        # Format first, then print a single string: this prints the same text
        # on Python 2 and Python 3. The former ``print (...).format(...)``
        # form raised AttributeError on Python 3, where print() returns None.
        print('Loading snapshot weights from {:s}'.format(snapshot_path))
        self._solver.net.copy_from(snapshot_path)

        # Only the file name is compared against the prefix that new
        # snapshots will be given.
        snapshot_name = os.path.basename(snapshot_path)
        if snapshot_name.startswith(self._snapshot_prefix):
            print('Warning! Existing snapshots may be overriden by new snapshots!')

    self._db = db
    self._solver.net.layers[0].set_db(self._db, do_flip)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号