def __init__(self, def_path, data_path, target_toolkit, input_shape=None, phase='test'):
    """Build (and optionally populate) the intermediate graph for a Caffe model.

    Args:
        def_path: Path to the network definition prototxt. May be ``None``,
            in which case a prototxt is generated from ``data_path`` via
            ``self.gen_prototxt_from_caffemodel``.
        data_path: Path to the trained weights, or ``None`` to build the
            graph without loading any parameters.
        target_toolkit: Name of the destination framework. When parameters
            are loaded and this is not ``'caffe'``/``'caffe2'``
            (case-insensitive), parameters are additionally reshaped by
            ``DataReshaper`` and converted by ``ParameterNamer``.
        input_shape: Input shape as a list or tuple. It is required when
            ``def_path`` is ``None`` or when the prototxt contains
            training-only layers (kinds in ``LAYER_IN_TRAIN_PROTO``). In
            those cases a leading dimension of 1 (presumably the batch
            axis) is prepended and the result is stored as a list.
        phase: Forwarded unchanged to ``GraphBuilder``.

    Raises:
        ConversionError: If ``input_shape`` is missing in a case where it
            is required.
    """
    self.layer_name_map = {}
    self.data_injector = None
    self.is_train_proto = False
    self.input_shape = input_shape

    if def_path is None:
        # No prototxt given: synthesise one from the weights file, which
        # requires an explicit input shape.
        if self.input_shape is None:
            raise ConversionError('if the graph prototxt is not provided, the input shape should be provided')
        # list() so that tuple shapes are accepted as well as lists.
        self.input_shape = [1] + list(self.input_shape)
        def_path, self.data_injector = self.gen_prototxt_from_caffemodel(data_path, self.input_shape)
        self.is_train_proto = True
    else:
        model = get_caffe_resolver().NetParameter()
        with open(def_path, 'r') as f:
            text_format.Merge(f.read(), model)
        # Use whichever repeated field is populated (`layers` or `layer`;
        # presumably the legacy vs. current Caffe schema).
        layers = model.layers or model.layer
        # A train/val prototxt is detected by the presence of any
        # training-only layer; any() stops at the first match.
        if any(NodeKind.map_raw_kind(layer.type) in LAYER_IN_TRAIN_PROTO for layer in layers):
            if self.input_shape is None:
                raise ConversionError('the train_val.prototxt should be provided with the input shape')
            self.input_shape = [1] + list(self.input_shape)
            self.is_train_proto = True

    graph = GraphBuilder(def_path, self.input_shape, self.is_train_proto, phase).build()
    if self.is_train_proto:
        # Load parameters against the prototxt exposed by the built graph
        # rather than the one passed in (or generated above).
        def_path = graph.prototxt

    if data_path is not None:
        graph = graph.transformed([
            # Reuse the injector returned by prototxt generation, if any.
            self.data_injector or DataInjector(def_path, data_path),  # Load and associate learned parameters
            BatchNormScaleBiasFuser(),
            BatchNormPreprocessor(),  # Pre-process batch normalization data
        ])
        # Caffe/Caffe2 targets keep Caffe's native parameter layout.
        if target_toolkit.lower() not in ('caffe', 'caffe2'):
            graph = graph.transformed([
                DataReshaper({  # Reshape the parameters to TensorFlow's ordering
                    NodeKind.Convolution: (2, 3, 1, 0),  # (c_o, c_i, h, w) -> (h, w, c_i, c_o)
                    NodeKind.InnerProduct: (1, 0),  # (c_o, c_i) -> (c_i, c_o)
                }),
                ParameterNamer(),  # Convert parameters to dictionaries
            ])

    self.graph = graph
    print_stderr(self.graph)
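# 评论列表 ("Comment list") — stray label from the source web page, not code.
# 文章目录 ("Table of contents") — stray label from the source web page, not code.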