caffe_net.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:pytorch-yolo2 作者: marvis 项目源码 文件源码
def load_weigths_from_caffe(self, protofile, caffemodel):
        """Copy trained parameters from a Caffe model into this network's layers.

        The misspelling "weigths" in the name is kept so existing callers keep
        working.

        Modules in ``self.models`` are matched to Caffe layers by name:

        * ``nn.Conv2d``: weight <- blob 0; bias <- blob 1 when Caffe has one.
        * ``nn.BatchNorm2d``: running mean / variance <- blobs 0 / 1, rescaled
          by the moving-average factor Caffe keeps in blob 2; affine weight /
          bias <- blobs 0 / 1 of the first Scale layer whose first bottom is
          this BatchNorm's first top.
        * Every other module type is left untouched.

        Args:
            protofile: Path to the Caffe network definition, passed to
                ``caffe.Net``.
            caffemodel: Path to the trained Caffe weights, passed to
                ``caffe.Net``.

        Raises:
            KeyError: if a Conv2d / BatchNorm2d name (or a matched Scale layer
                name) has no entry in the Caffe net's ``params``.

        Note:
            ``torch.from_numpy`` shares memory with the source array, so the
            loaded conv / scale parameters alias the Caffe blob buffers.
            # NOTE(review): assumes self.models maps Caffe layer names to
            # modules, self.layer_map_to_top maps a layer name to a sequence of
            # top blob names, and self.net_info['layers'] is a list of dicts
            # with 'type' / 'name' / 'bottom' keys -- confirm against the class.
        """
        caffe.set_mode_cpu()
        net = caffe.Net(protofile, caffemodel, caffe.TEST)

        # First-bottom blob name -> name of the FIRST Scale layer reading it.
        # Built once instead of rescanning all layers for every BatchNorm.
        scale_by_bottom = {}
        for caffe_layer in self.net_info['layers']:
            if caffe_layer.get('type') != 'Scale':
                continue
            bottoms = caffe_layer['bottom']
            # A single bottom may be stored as a bare string; indexing it
            # directly would yield its first character, not the blob name.
            if not isinstance(bottoms, (list, tuple)):
                bottoms = [bottoms]
            scale_by_bottom.setdefault(bottoms[0], caffe_layer['name'])

        for name, layer in self.models.items():
            if isinstance(layer, nn.Conv2d):
                blobs = net.params[name]
                layer.weight.data = torch.from_numpy(blobs[0].data)
                if len(blobs) > 1:
                    layer.bias.data = torch.from_numpy(blobs[1].data)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                blobs = net.params[name]
                caffe_means = blobs[0].data
                caffe_var = blobs[1].data
                # Caffe stores un-normalized statistics plus a moving-average
                # factor in blob 2, and divides by that factor at inference
                # (a factor of 0 yields zeros, mirroring Caffe's BatchNormLayer).
                if len(blobs) > 2:
                    factor = float(blobs[2].data[0])
                    scale = 0.0 if factor == 0 else 1.0 / factor
                    caffe_means = caffe_means * scale
                    caffe_var = caffe_var * scale
                layer.running_mean = torch.from_numpy(caffe_means)
                layer.running_var = torch.from_numpy(caffe_var)

                # The affine part of a Caffe BatchNorm lives in a separate
                # Scale layer that consumes the BatchNorm's output.
                top_name_of_bn = self.layer_map_to_top[name][0]
                scale_name = scale_by_bottom.get(top_name_of_bn)
                if scale_name is not None:
                    scale_blobs = net.params[scale_name]
                    layer.weight.data = torch.from_numpy(scale_blobs[0].data)
                    # Check the Scale layer's OWN blob count: a Scale with
                    # bias_term: false has only one blob.
                    if len(scale_blobs) > 1:
                        layer.bias.data = torch.from_numpy(scale_blobs[1].data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号