def quantParam(ckpt_prefix='./LenetParam', suffixes=('_w:0', '_b:0'),
               verbose=True):
    """Restore a saved checkpoint and return selected trainable variables by name.

    The graph is imported from ``ckpt_prefix + '.meta'`` into a fresh
    ``tf.Graph`` rather than the process-wide default graph. Calling this
    function repeatedly therefore does not add duplicate copies of the
    network to the default graph.

    The variable values are then restored from ``ckpt_prefix``. Every
    trainable variable whose name ends with one of ``suffixes`` is evaluated.

    Uses the TF1-style API (``tf.Session``, ``tf.train.import_meta_graph``).
    NOTE(review): assumes ``tf`` is TensorFlow 1.x or ``tf.compat.v1`` —
    confirm against the file's import.

    Args:
        ckpt_prefix: Checkpoint prefix passed to ``Saver.restore``.
            ``ckpt_prefix + '.meta'`` must be the matching meta-graph file.
            Defaults to the previously hard-coded ``'./LenetParam'``.
        suffixes: Iterable of variable-name endings to select. Defaults to
            ``('_w:0', '_b:0')``, the previously hard-coded filter.
        verbose: If true (the default, matching the original behaviour),
            print the selected names and the resulting dict keys.

    Returns:
        dict mapping each selected variable name to its value as returned by
        ``Session.run``. Entries are inserted in ``tf.trainable_variables()``
        order. All values are returned dense; no sparse conversion is applied.
    """
    # str.endswith accepts a tuple, which tests all suffixes in one call.
    suffixes = tuple(suffixes)
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
        saver.restore(sess, ckpt_prefix)
        # Keep the variable objects themselves, so there is no need to
        # re-scan tf.trainable_variables() for each name (previously O(n^2)).
        params = [var for var in tf.trainable_variables()
                  if var.name.endswith(suffixes)]
        names = [var.name for var in params]
        if verbose:
            print(names)
            for name in names:
                print(name)
        # Fetch every selected value in a single Session.run call.
        values = sess.run(params)
    param_dict = dict(zip(names, values))
    if verbose:
        print(param_dict.keys())
    return param_dict
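# 评论列表 ("Comment list") and 文章目录 ("Table of contents"): navigation
# labels left over from the web page this code was copied from. They are not
# part of the program.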