BN-absorber-enet.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:ENet 作者: TimoSaemann 项目源码 文件源码
def bn_absorber_prototxt(model):

    # load the prototxt file as a protobuf message
    with open(model) as k:
        str1 = k.read()
    msg1 = caffe_pb2.NetParameter()
    text_format.Merge(str1, msg1)

    # search for bn layer and remove them
    for i, l in enumerate(msg1.layer):
        if l.type == "BN":
            if msg1.layer[i].name == 'bn0_1':
                continue
            if msg1.layer[i - 1].type == 'Deconvolution':
                continue
            msg1.layer.remove(l)
            msg1.layer[i].bottom.append(msg1.layer[i-1].top[0])

            if len(msg1.layer[i].bottom) == 2:
                msg1.layer[i].bottom.remove(msg1.layer[i].bottom[0])
            elif len(msg1.layer[i].bottom) == 3:
                if ('bn' in msg1.layer[i].bottom[0]) is True:  # to remove just the layers with 'bn' in the name
                    msg1.layer[i].bottom.remove(msg1.layer[i].bottom[0])
                elif ('bn' in msg1.layer[i].bottom[1]) is True:
                    msg1.layer[i].bottom.remove(msg1.layer[i].bottom[1])
                else:
                    raise Exception("no bottom blob with name 'bn' present in {} layer".format(msg1.layer[i]))

            else:
                raise Exception("bn absorber does not support more than 2 input blobs for layer {}"
                                .format(msg1.layer[i]))

            if msg1.layer[i].type == 'Upsample':
                temp = msg1.layer[i].bottom[0]
                msg1.layer[i].bottom[0] = msg1.layer[i].bottom[1]
                msg1.layer[i].bottom[1] = temp
                # l.bottom.append(l.top[0]) #msg1.layer[i-1].top

    return msg1
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号