tensorflow.train.import_meta_graph不起作用?

发布于 2021-01-29 17:56:00

我尝试简单地保存和恢复图形,但是最简单的示例无法按预期工作(在没有CUDA且使用python 2.7或3.5.2的Linux
64上使用0.9.0或0.10.0版完成此操作)

首先,我像这样保存图形:

import tensorflow as tf
v1 = tf.placeholder('float32') 
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])

这将创建一个非空文件“文件”,并将g1设置为看起来像正确的图形定义的文件。

然后,我尝试还原此图:

import tensorflow as tf
g=tf.train.import_meta_graph("file")

这可以正常工作,但不会返回任何内容。

任何人都可以提供必要的代码来简单地保存“ v4”图形并完全还原它,以便在新会话中运行该图形会产生相同的结果吗?

关注者
0
被浏览
144
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    要重用MetaGraphDef,您需要在原始图中记录有趣的张量的名称。例如,在第一程序中,设置明确name的定义中的参数v1v2v4

    v1 = tf.placeholder(tf.float32, name="v1")
    v2 = tf.placeholder(tf.float32, name="v2")
    # ...
    v4 = tf.add(v3, c1, name="v4")
    

    然后,您可以在调用中使用原始图中的张量的字符串名称sess.run()。例如,以下代码片段应该起作用:

    import tensorflow as tf
    _ = tf.train.import_meta_graph("./file")
    
    sess = tf.Session()
    result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
    

    另外,您可以使用tf.get_default_graph().get_tensor_by_name()获取tf.Tensor感兴趣的张量的对象,然后将其传递给sess.run()

    import tensorflow as tf
    _ = tf.train.import_meta_graph("./file")
    g = tf.get_default_graph()
    
    v1 = g.get_tensor_by_name("v1:0")
    v2 = g.get_tensor_by_name("v2:0")
    v4 = g.get_tensor_by_name("v4:0")
    
    sess = tf.Session()
    result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
    

    更新 :根据注释中的讨论,这里是保存和加载(包括保存变量内容)的完整示例。这说明了通过vx在单独的操作中将变量的值加倍来保存变量。

    保存:

    import tensorflow as tf
    v1 = tf.placeholder(tf.float32, name="v1") 
    v2 = tf.placeholder(tf.float32, name="v2")
    v3 = tf.mul(v1, v2)
    vx = tf.Variable(10.0, name="vx")
    v4 = tf.add(v3, vx, name="v4")
    saver = tf.train.Saver([vx])
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    sess.run(vx.assign(tf.add(vx, vx)))
    result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
    print(result)
    saver.save(sess, "./model_ex1")
    

    恢复:

    import tensorflow as tf
    saver = tf.train.import_meta_graph("./model_ex1.meta")
    sess = tf.Session()
    saver.restore(sess, "./model_ex1")
    result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
    print(result)
    

    最重要的是,为了使用保存的模型,您必须记住至少某些节点的名称(例如,训练操作,输入占位符,评估张量等)。该MetaGraphDef专卖店在训练中所包含的模型,并有助于从检查点恢复这些,但你必须重建张量的变量列表/使用/自己评估模型操作。



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看