infSparse.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:nn-compression 作者: anithapk 项目源码 文件源码
def quantParam(ckpt_path='./LenetParam'):
    """Restore a saved TF1 checkpoint and return its weight/bias values.

    Imports the meta graph ``ckpt_path + '.meta'`` into the current default
    graph, restores the checkpoint ``ckpt_path`` into a new session, and
    evaluates every trainable variable whose name ends in ``_w:0`` or
    ``_b:0``.

    Args:
        ckpt_path: Checkpoint prefix passed to ``Saver.restore``; the meta
            graph is read from ``ckpt_path + '.meta'``. Defaults to
            ``'./LenetParam'``, the path that used to be hard-coded here.

    Returns:
        dict mapping each selected variable name to the value returned by
        ``sess.run`` for it, in ``tf.trainable_variables()`` order.

    Raises:
        ValueError: if the meta graph contains no variables, so no saver
            can be built to restore the checkpoint.
    """
    param_dict = {}
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        if saver is None:
            # import_meta_graph returns None when the graph holds no
            # variables; fail with a clear message instead of an
            # AttributeError on ``None.restore``.
            raise ValueError('No variables found in %s.meta' % ckpt_path)
        saver.restore(sess, ckpt_path)
        # NOTE(review): the meta graph is imported into the *default* graph,
        # so repeated calls in one process may add duplicate/renamed ops.
        # Confirm callers call this once, or wrap in tf.Graph().as_default().
        # Weights and biases are selected purely by their name suffix.
        params = [var for var in tf.trainable_variables()
                  if var.name.endswith(('_w:0', '_b:0'))]
        lay_name = [var.name for var in params]
        print(lay_name)
        for name in lay_name:
            print(name)
        # One batched fetch instead of one session call per variable; the
        # returned list lines up element-wise with ``params``.
        values = sess.run(params)
        param_dict.update(zip(lay_name, values))
    print(param_dict.keys())
    return param_dict
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号