transformer.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目: MMdnn　作者: Microsoft　项目源码　文件源码
def __init__(self, def_path, data_path, target_toolkit, input_shape=None, phase='test'):
        """Load a Caffe model into an internal graph and pre-process its parameters.

        Args:
            def_path: Path to the network definition (prototxt). When ``None``,
                a definition is generated from ``data_path`` via
                ``self.gen_prototxt_from_caffemodel`` and ``input_shape`` is
                required.
            data_path: Path to the learned parameters (caffemodel), or ``None``
                to build the graph without loading any parameters.
            target_toolkit: Name of the conversion target. Unless it is
                ``'caffe'`` or ``'caffe2'`` (case-insensitive), parameters are
                reshaped by ``DataReshaper`` and converted by ``ParameterNamer``.
                Only consulted when ``data_path`` is not ``None``.
            input_shape: Input shape without its leading dimension; a ``1`` is
                prepended (presumably the batch size -- confirm). Any sequence
                (list or tuple) is accepted. Required when ``def_path`` is
                ``None`` or when the definition contains a layer kind listed in
                ``LAYER_IN_TRAIN_PROTO``.
            phase: Forwarded unchanged to ``GraphBuilder``.

        Raises:
            ConversionError: If ``input_shape`` is missing where it is required.
        """
        self.layer_name_map = {}
        self.data_injector = None
        self.is_train_proto = False
        self.input_shape = input_shape
        if def_path is None:
            if self.input_shape is None:
                raise ConversionError('if the graph prototxt is not provided, the input shape should be provided')
            # list() lets callers pass a tuple; ``[1] + tuple`` would raise TypeError.
            self.input_shape = [1] + list(self.input_shape)
            # The generated definition comes with its own data injector, reused below.
            def_path, self.data_injector = self.gen_prototxt_from_caffemodel(data_path, self.input_shape)
            self.is_train_proto = True
        else:
            model = get_caffe_resolver().NetParameter()
            with open(def_path, 'r') as f:
                text_format.Merge(f.read(), model)
            # Fall back to ``layer`` when ``layers`` is empty (presumably the
            # legacy vs. current Caffe field names).
            layers = model.layers or model.layer
            # A layer kind from LAYER_IN_TRAIN_PROTO marks a train/val definition,
            # which cannot be built without an explicit input shape.
            if any(NodeKind.map_raw_kind(layer.type) in LAYER_IN_TRAIN_PROTO for layer in layers):
                if self.input_shape is None:
                    raise ConversionError('the train_val.prototxt should be provided with the input shape')
                self.input_shape = [1] + list(self.input_shape)
                self.is_train_proto = True
        graph = GraphBuilder(def_path, self.input_shape, self.is_train_proto, phase).build()
        if self.is_train_proto:
            # From here on, use the prototxt attached to the built graph.
            def_path = graph.prototxt
        if data_path is not None:
            graph = graph.transformed([
                # Reuse the injector from gen_prototxt_from_caffemodel when one exists.
                self.data_injector if self.data_injector else DataInjector(def_path, data_path),  # Load and associate learned parameters
                BatchNormScaleBiasFuser(),
                BatchNormPreprocessor(),  # Pre-process batch normalization data
            ])
            target_toolkit = target_toolkit.lower()
            if target_toolkit not in ('caffe', 'caffe2'):
                graph = graph.transformed([
                    DataReshaper({  # Reshape the parameters to TensorFlow's ordering
                        NodeKind.Convolution: (2, 3, 1, 0),  # (c_o, c_i, h, w) -> (h, w, c_i, c_o)
                        NodeKind.InnerProduct: (1, 0),  # (c_o, c_i) -> (c_i, c_o)
                    }),
                    ParameterNamer(),  # Convert parameters to dictionaries
                ])
        self.graph = graph
        print_stderr(self.graph)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号