Util.py source code

python

Project: MLPractices  Author: carefree0910
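import tensorflow as tf  # TensorFlow 1.x API (tf.GraphDef, tf.get_default_graph)
from tensorflow.python.platform import gfile


def load_frozen_graph(graph_dir, fix_nodes=True, entry=None, output=None):
    # Read the serialized GraphDef from a frozen .pb file
    with gfile.FastGFile(graph_dir, "rb") as file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(file.read())
    if fix_nodes:
        # Frozen graphs containing batch normalization can keep ref-typed ops
        # (RefSwitch, AssignSub) that fail on import; rewrite them as their
        # non-ref counterparts so the graph can be loaded
        for node in graph_def.node:
            if node.op == "RefSwitch":
                node.op = "Switch"
                for index in range(len(node.input)):
                    if "moving_" in node.input[index]:
                        # Read the moving statistics through their "/read" snapshot
                        node.input[index] = node.input[index] + "/read"
            elif node.op == "AssignSub":
                node.op = "Sub"
                if "use_locking" in node.attr:
                    del node.attr["use_locking"]
    # Import into the default graph without adding a name prefix
    tf.import_graph_def(graph_def, name="")
    graph = tf.get_default_graph()
    # Optionally resolve the given tensor names (e.g. "input:0") to tf.Tensor objects
    if entry is not None:
        entry = graph.get_tensor_by_name(entry)
    if output is not None:
        output = graph.get_tensor_by_name(output)
    return entry, output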
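To see what the fix_nodes pass does without TensorFlow installed, here is a minimal sketch that applies the same RefSwitch/AssignSub rewrite rules to a plain-Python stand-in for a GraphDef node. The Node dataclass, the fix_ref_nodes helper, and all node names below are hypothetical illustrations, not TensorFlow's protobuf classes or the author's code.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    # Hypothetical stand-in for a GraphDef NodeDef: name, op type, inputs, attributes
    name: str
    op: str
    input: List[str] = field(default_factory=list)
    attr: Dict[str, object] = field(default_factory=dict)


def fix_ref_nodes(nodes: List[Node]) -> List[Node]:
    """Apply the same rewrites as load_frozen_graph's fix_nodes pass, in place."""
    for node in nodes:
        if node.op == "RefSwitch":
            node.op = "Switch"
            for index, name in enumerate(node.input):
                # Redirect moving-statistics inputs to their "/read" snapshot
                if "moving_" in name:
                    node.input[index] = name + "/read"
        elif node.op == "AssignSub":
            node.op = "Sub"
            # Sub has no use_locking attribute, so drop it if present
            node.attr.pop("use_locking", None)
    return nodes


# Hypothetical nodes resembling a batch-norm subgraph plus an untouched MatMul
nodes = fix_ref_nodes([
    Node("bn/cond/Switch", "RefSwitch", ["bn/moving_mean", "is_training"]),
    Node("bn/AssignMovingAvg", "AssignSub", ["bn/moving_mean", "bn/delta"],
         {"use_locking": False, "T": "float"}),
    Node("dense/MatMul", "MatMul", ["x", "dense/kernel"]),
])
for n in nodes:
    print(n.name, n.op, n.input, n.attr)
```

Only RefSwitch and AssignSub nodes change; every other op passes through untouched. With a real frozen model, a call to the original function would look like entry, output = load_frozen_graph("model.pb", entry="input:0", output="output:0"), where the file path and tensor names are placeholders for your own graph.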