tensorflow_parser.py source code

python

Project: MMdnn  Author: Microsoft
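def __init__(self, input_args, dest_nodes=None):
    super(TensorflowParser, self).__init__()

    # Load the model file(s) into a TensorFlow GraphDef
    from six import string_types as _string_types
    if isinstance(input_args, _string_types):
        # A single .meta path: graph structure only, no weights
        model = TensorflowParser._load_meta(input_args)
    elif isinstance(input_args, tuple):
        # (meta_path, checkpoint_path): graph structure plus weights
        model = TensorflowParser._load_meta(input_args[0])
        self.ckpt_data = TensorflowParser._load_weights(input_args[1])
        self.weight_loaded = True
    else:
        # Without this branch, `model` would be unbound below
        raise TypeError("input_args must be a .meta path or a (meta_path, checkpoint_path) tuple")

    if dest_nodes is not None:
        # Keep only the nodes needed to compute the comma-separated output nodes
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()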