def __init__(self, input_args, dest_nodes = None):
    """Load a TensorFlow model and build its network graph.

    Args:
        input_args: Either a string, which is passed to ``_load_meta``, or a
            tuple whose first element is passed to ``_load_meta`` and whose
            second element is passed to ``_load_weights`` (the result is
            stored in ``self.ckpt_data`` and ``self.weight_loaded`` is set).
        dest_nodes: Optional output node names, either as a comma-separated
            string (e.g. ``"out1,out2"``) or as an iterable of names. When
            given, the loaded graph is pruned with ``extract_sub_graph`` to
            the part needed to compute these nodes.

    Raises:
        TypeError: if ``input_args`` is neither a string nor a tuple.
        ValueError: if ``dest_nodes`` is given but names no node.
    """
    super(TensorflowParser, self).__init__()

    # Load the TensorFlow model files (and, for a tuple, the checkpoint weights).
    from six import string_types as _string_types
    if isinstance(input_args, _string_types):
        model = TensorflowParser._load_meta(input_args)
    elif isinstance(input_args, tuple):
        model = TensorflowParser._load_meta(input_args[0])
        self.ckpt_data = TensorflowParser._load_weights(input_args[1])
        # NOTE(review): weight_loaded is only assigned on this path; presumably
        # the base class initializes it for the string path -- verify.
        self.weight_loaded = True
    else:
        # Without this branch, `model` stays unbound and the failure surfaces
        # later as an UnboundLocalError; report the real problem up front.
        raise TypeError(
            "input_args must be a path string or a tuple, got %s"
            % type(input_args).__name__)

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        if isinstance(dest_nodes, _string_types):
            dest_nodes = dest_nodes.split(',')
        # Tolerate "a, b" and trailing commas: TensorFlow node names do not
        # contain whitespace, so stripping cannot alter a valid name.
        node_names = [name.strip() for name in dest_nodes if name.strip()]
        if not node_names:
            raise ValueError("dest_nodes does not name any node")
        model = extract_sub_graph(model, node_names)

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
评论列表
文章目录