convertmodel.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:DmsMsgRcg 作者: bshao001 项目源码 文件源码
def __init__(self, config, graph, model_scope, model_dir, model_file,
             input_tensor_name="input_1:0", output_tensor_name="s1_output0:0"):
    """Load a frozen GraphDef into ``graph`` and look up its input/output tensors.

    Args:
        config: Stored unchanged on ``self.config``; not otherwise used here.
        graph: The ``tf.Graph`` that the frozen model is imported into and
            from which the input/output tensors are fetched.
        model_scope: Name scope prefixed to every imported node, so tensor
            names become ``"<model_scope>/<original name>"``.
        model_dir: Directory containing the frozen model file.
        model_file: File name of the serialized ``GraphDef``, read in binary mode.
        input_tensor_name: Input tensor name inside the frozen graph, without
            the scope prefix. Defaults to the original hard-coded ``"input_1:0"``.
        output_tensor_name: Output tensor name inside the frozen graph, without
            the scope prefix. Defaults to the original hard-coded ``"s1_output0:0"``.

    Errors from opening/parsing the file and from ``graph.get_tensor_by_name``
    (e.g. when a tensor name does not exist) are not caught and propagate.
    """
    self.config = config

    frozen_model = os.path.join(model_dir, model_file)
    with tf.gfile.GFile(frozen_model, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # tf.import_graph_def always targets the *default* graph. Make ``graph``
    # the default explicitly. Otherwise the nodes land in whatever graph the
    # caller has active, and the get_tensor_by_name lookups below fail.
    # When ``graph`` is already the default this is a harmless no-op.
    with graph.as_default():
        # model_scope is prepended to all imported node names. The trailing
        # "/" asks TF to use the scope name as given rather than uniquifying it.
        tf.import_graph_def(graph_def, input_map=None, return_elements=None,
                            name="{}/".format(model_scope))

    # To discover the tensor names in a frozen model, iterate over
    # graph.get_operations() and print each op.name. Alternatively, take the
    # names from the original model if it named them explicitly.
    self.input_tensor = graph.get_tensor_by_name(
        "{}/{}".format(model_scope, input_tensor_name))
    self.output_tensor = graph.get_tensor_by_name(
        "{}/{}".format(model_scope, output_tensor_name))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号