tensorflow.train.import_meta_graph不起作用?
我尝试简单地保存和恢复图形,但是最简单的示例无法按预期工作(在没有CUDA且使用python 2.7或3.5.2的Linux
64上使用0.9.0或0.10.0版完成此操作)
首先,我像这样保存图形:
import tensorflow as tf
v1 = tf.placeholder('float32')
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])
这将创建一个非空文件“文件”,并将g1设置为看起来像正确的图形定义的文件。
然后,我尝试还原此图:
import tensorflow as tf
g=tf.train.import_meta_graph("file")
这可以正常工作,但不会返回任何内容。
任何人都可以提供必要的代码来简单地保存“ v4”图形并完全还原它,以便在新会话中运行该图形会产生相同的结果吗?
-
要重用
MetaGraphDef
,您需要在原始图中记录有趣的张量的名称。例如,在第一程序中,设置明确name
的定义中的参数v1
,v2
和v4
:v1 = tf.placeholder(tf.float32, name="v1") v2 = tf.placeholder(tf.float32, name="v2") # ... v4 = tf.add(v3, c1, name="v4")
然后,您可以在调用中使用原始图中的张量的字符串名称
sess.run()
。例如,以下代码片段应该起作用:import tensorflow as tf _ = tf.train.import_meta_graph("./file") sess = tf.Session() result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
另外,您可以使用
tf.get_default_graph().get_tensor_by_name()
获取tf.Tensor
感兴趣的张量的对象,然后将其传递给sess.run()
:import tensorflow as tf _ = tf.train.import_meta_graph("./file") g = tf.get_default_graph() v1 = g.get_tensor_by_name("v1:0") v2 = g.get_tensor_by_name("v2:0") v4 = g.get_tensor_by_name("v4:0") sess = tf.Session() result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
更新 :根据注释中的讨论,这里是保存和加载(包括保存变量内容)的完整示例。这说明了通过
vx
在单独的操作中将变量的值加倍来保存变量。保存:
import tensorflow as tf v1 = tf.placeholder(tf.float32, name="v1") v2 = tf.placeholder(tf.float32, name="v2") v3 = tf.mul(v1, v2) vx = tf.Variable(10.0, name="vx") v4 = tf.add(v3, vx, name="v4") saver = tf.train.Saver([vx]) sess = tf.Session() sess.run(tf.initialize_all_variables()) sess.run(vx.assign(tf.add(vx, vx))) result = sess.run(v4, feed_dict={v1:12.0, v2:3.3}) print(result) saver.save(sess, "./model_ex1")
恢复:
import tensorflow as tf saver = tf.train.import_meta_graph("./model_ex1.meta") sess = tf.Session() saver.restore(sess, "./model_ex1") result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) print(result)
最重要的是,为了使用保存的模型,您必须记住至少某些节点的名称(例如,训练操作,输入占位符,评估张量等)。该
MetaGraphDef
专卖店在训练中所包含的模型,并有助于从检查点恢复这些,但你必须重建张量的变量列表/使用/自己评估模型操作。