def solver_graph(self):
    """Build the Caffe ``SolverParameter`` for this worker.

    Settings are read from ``self.cmd`` (``solver_type``, ``learning_rate``,
    ``momentum``, ``random_seed``, ``graph``). The device comes from
    ``self.device`` and the worker index from ``self.rank``.

    When ``self.cmd.graph`` is set, the network is loaded from
    ``<graph>.prototxt`` in the directory containing this source file.
    Otherwise, the train and test nets are taken from
    ``self.net_def(caffe.TRAIN)`` and ``self.net_def(caffe.TEST)``.

    Returns:
        caffe_pb2.SolverParameter: the populated solver definition.
    """
    proto = caffe_pb2.SolverParameter()
    proto.type = self.cmd.solver_type

    # Use the GPU when a device id was assigned; otherwise fall back to CPU.
    if self.device is not None:
        proto.solver_mode = caffe_pb2.SolverParameter.SolverMode.Value(
            'GPU')
        proto.device_id = self.device
    else:
        proto.solver_mode = caffe_pb2.SolverParameter.SolverMode.Value(
            'CPU')

    proto.lr_policy = 'fixed'
    proto.base_lr = self.cmd.learning_rate
    proto.momentum = self.cmd.momentum
    # Very large iteration cap. Presumably the caller decides when to stop
    # training (TODO confirm against the training loop).
    proto.max_iter = int(2e9)
    # Offset the seed by the worker rank so each worker gets its own seed.
    proto.random_seed = self.cmd.random_seed + self.rank
    print('Setting seed ', proto.random_seed, file=sys.stderr)
    proto.display = 1

    # NOTE: a former `batch = int(solver.cmd.input_shape[0] / solver.size)`
    # line was removed. It referenced the undefined name `solver` (not
    # `self`), and its result was never used.
    if self.cmd.graph:
        here = os.path.dirname(os.path.realpath(__file__))
        proto.net = os.path.join(here, self.cmd.graph + '.prototxt')
    else:
        proto.train_net_param.MergeFrom(self.net_def(caffe.TRAIN))
        proto.test_net_param.add().MergeFrom(self.net_def(caffe.TEST))

    proto.test_iter.append(1)
    proto.test_interval = 999999999  # cannot disable or set to 0
    proto.test_initialization = False
    return proto
评论列表
文章目录