model_transform.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:caffe-tensorflow 作者: blankWorld 项目源码 文件源码
def get_caffe_variables(self, net_proto, net_model=None, bn_name=''):
        """Load a Caffe net and collect its parameters into ``self.blob_dict``.

        The net is built in CPU mode and TEST phase and stored on
        ``self.net_caffe``.  Every entry of ``self.net_caffe.params`` is then
        copied into ``self.blob_dict`` under TensorFlow-style variable names
        (``"<layer_name>/weights:0"`` and ``"<layer_name>/biases:0"``) so that
        Caffe parameters can be matched against TF variables by name.

        Args:
            net_proto: Network definition passed to ``caffe.Net``.
            net_model: Optional trained weights passed to ``caffe.Net``.  When
                ``None`` the net is built from ``net_proto`` alone.
            bn_name: Substring identifying batch-normalization layers.  A
                layer with three parameter blobs whose name contains
                ``bn_name`` is treated as batch normalization.  The default
                ``''`` matches every layer name.

        Returns:
            None.  Results are stored in ``self.blob_dict``.
        """
        caffe.set_mode_cpu()
        self.blob_dict = {}
        if net_model is not None:
            self.net_caffe = caffe.Net(net_proto, net_model, caffe.TEST)
        else:
            self.net_caffe = caffe.Net(net_proto, caffe.TEST)

        # NOTE: TF variable names and Caffe parameter names must match, so the
        # Caffe names are rewritten here before being saved in blob_dict.
        for layer_name, param in self.net_caffe.params.items():
            prefix = str(layer_name)
            weights_key = prefix + "/weights:0"
            biases_key = prefix + "/biases:0"
            param_len = len(param)

            # Batch-normalization layers are recognised by having three blobs
            # and a name containing `bn_name`; adjust `bn_name` to your net.
            if param_len == 3 and bn_name in layer_name:
                # Blobs 0 and 1 hold the accumulated mean and variance, blob 2
                # the moving-average factor they must be divided by.  That
                # factor is 0 for a net whose statistics were never
                # accumulated (e.g. no weights loaded); Caffe's BatchNorm
                # layer uses a scale of 0 in that case, and so do we, instead
                # of dividing by zero and filling the blobs with inf/nan.
                moving_average = param[2].data[0]
                if moving_average == 0:
                    scale_factor = 0.0
                else:
                    scale_factor = 1.0 / moving_average
                # The mean is stored under the weights key and the variance
                # under the biases key.
                self.blob_dict[weights_key] = param[0].data * scale_factor
                self.blob_dict[biases_key] = param[1].data * scale_factor
            elif param_len == 2:
                self.blob_dict[weights_key] = param[0].data
                self.blob_dict[biases_key] = param[1].data
            elif param_len == 1:
                self.blob_dict[weights_key] = param[0].data
            # Layers with any other blob layout (including three-blob layers
            # whose name does not contain `bn_name`) are skipped.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号