convToSparse.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:nn-compression 作者: anithapk 项目源码 文件源码
def conv2SaveSparse(chkPt,outDir):
    #conver weights to sparse format
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(chkPt+".meta")
        saver.restore(sess,"./"+chkPt)
        lay_name = [v.name for v in tf.trainable_variables() if (v.name.endswith("_w:0"))]
        for v in lay_name:
            print(v)
            curLay = [a for a in tf.trainable_variables() if (a.name==v)]
            wt = curLay[0].eval()
            print("np:",np.where(wt!=0)[0].shape)
            ind = tf.where(tf.not_equal(wt, 0))
            sparse = tf.SparseTensor(ind, tf.gather_nd(wt, ind), curLay[0].get_shape())
            tmp = sess.run(sparse)
            valName = outDir+v+"spVal.npy"
            print(valName)
            with open(valName,'wb') as f:
                np.save(f,tmp[1])
            valName = outDir+v+"spMatSize.npy"
            print(valName)
            with open(valName,'wb') as f:
                np.save(f,tmp[2])
            print("tmp",[tmp[0].shape,tmp[0].dtype,tmp[1].shape,tmp[2]])
            indMat64 = tmp[0]
            castIndMat64 = tf.cast(indMat64,tf.uint16)
            indMat16 = sess.run(castIndMat64)
            print("intMat16:",[indMat16.shape,indMat16.dtype])
            valName = outDir+v+"spInd16.npy"
            print(valName)
            with open(valName,'wb') as f:
                np.save(f,tmp[0])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号