infQuant.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:nn-compression 作者: anithapk 项目源码 文件源码
def quantParam():
    """Restore the saved LeNet checkpoint and quantize its conv weights to qint16.

    Loads the graph from ``./LenetParam.meta`` and restores variable values
    from ``./LenetParam``. It then walks every trainable variable whose name
    ends in ``_w:0`` or ``_b:0``:

    * Variables whose name starts with ``conv`` and ends with ``_w:0`` are
      quantized with ``tf.quantize_v2``. The target dtype is ``tf.qint16`` and
      the mode is ``MIN_FIRST``. The quantization range is the tensor's own
      min and max.
    * All other matched variables are stored as their evaluated values,
      unchanged.

    The keys of both result dicts are printed before returning.

    Returns:
        tuple: ``(paramDict, minMaxDict)``.

        * ``paramDict`` maps each matched variable name to its quantized
          output (conv weights) or raw evaluated value (everything else).
        * ``minMaxDict`` maps each quantized variable name to
          ``[output_min, output_max]`` as reported by ``tf.quantize_v2``.
    """
    paramDict = {}
    minMaxDict = {}
    # Names matching BOTH of these are conv-layer weights and get quantized.
    conv_prefix, conv_suffix = "conv", "_w:0"
    # Only weight/bias variables are collected at all.
    param_suffixes = ("_w:0", "_b:0")
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./LenetParam.meta')
        saver.restore(sess, './LenetParam')
        # One ordered pass over the variable list. The original rescanned the
        # full list for every layer to find the variable by name.
        for var in tf.trainable_variables():
            name = var.name
            if not name.endswith(param_suffixes):
                continue
            # `sess` is the default session inside this `with`, so eval() uses it.
            cur_wt = var.eval()
            if name.startswith(conv_prefix) and name.endswith(conv_suffix):
                # cur_wt is already a NumPy value. Taking min/max here avoids
                # adding two reduce ops to the graph on every iteration.
                quant = tf.quantize_v2(cur_wt, cur_wt.min(), cur_wt.max(),
                                       tf.qint16, mode="MIN_FIRST",
                                       name="quant32to16")
                chk = sess.run(quant)
                paramDict[name] = chk.output
                minMaxDict[name] = [chk.output_min, chk.output_max]
            else:
                paramDict[name] = cur_wt
    # NOTE(review): the meta graph is imported into the process-wide default
    # graph, so calling this twice in one process imports it twice.
    # Confirm this function is only called once per process.
    print(paramDict.keys())
    print(minMaxDict.keys())
    return paramDict, minMaxDict
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号