def __init__(self, solver_prototxt, pretrained_model=None):
    """Initialize the SolverWrapper.

    Builds a ``caffe.SGDSolver`` from ``solver_prototxt``, optionally copies
    pretrained weights into its net, parses the same prototxt into
    ``self.solver_param``, selects the Caffe compute mode from the parsed
    ``solver_mode``, and finally calls ``set_db()`` on the net's first layer.

    Args:
        solver_prototxt: Path to the solver definition file. It is passed
            to ``caffe.SGDSolver`` and also read as text so it can be
            merged into a ``SolverParameter`` message.
        pretrained_model: Optional path to a weights file handed to
            ``net.copy_from``. ``None`` (the default) skips weight loading.
    """
    self.solver = caffe.SGDSolver(solver_prototxt)

    if pretrained_model is not None:
        # A single-argument print(...) emits the same text whether print is
        # a statement (Python 2) or a function (Python 3); formatting the
        # message inside the parentheses avoids calling .format on None.
        print('Loading pretrained model '
              'weights from {:s}'.format(pretrained_model))
        self.solver.net.copy_from(pretrained_model)

    # Parse the solver prototxt a second time to get at its settings.
    self.solver_param = caffe.io.caffe_pb2.SolverParameter()
    with open(solver_prototxt, 'rt') as f:
        text_format.Merge(f.read(), self.solver_param)

    # solver_mode == 1 is treated as "GPU"; any other value falls through
    # to the CPU branch, which only reports the choice.
    if self.solver_param.solver_mode == 1:
        caffe.set_mode_gpu()
        # `params` is not defined in this method; presumably a module-level
        # configuration object -- verify against the rest of the file.
        caffe.set_device(params.gpu_id)
        print('Use GPU {} to train'.format(params.gpu_id))
    else:
        print('Use CPU to train')

    # Initialise the python data layer. NOTE(review): assumes layer 0 of the
    # net is the data layer that exposes set_db() -- confirm.
    self.solver.net.layers[0].set_db()
评论列表
文章目录