def fromGraphDef(cls, graph_def, feed_names, fetch_names):
    """
    Build a TFInputGraph out of an in-memory ``tf.GraphDef``.

    The graph definition is imported, with an empty name prefix, into a freshly
    created ``tf.Graph``; the result is then produced by
    ``_build_with_feeds_fetches`` while a session bound to that graph is open.

    :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology
                      and computation units of the TensorFlow graph.
    :param feed_names: list, names of the input tensors.
    :param fetch_names: list, names of the output tensors.
    :return: whatever ``_build_with_feeds_fetches`` returns for the imported graph.
    :raises AssertionError: if ``graph_def`` is not a ``tf.GraphDef`` instance
                            (only when assertions are enabled).
    """
    assert isinstance(graph_def, tf.GraphDef), (
        'expect tf.GraphDef type but got',
        type(graph_def),
    )

    # Import into a dedicated graph rather than whatever graph is currently the default.
    target_graph = tf.Graph()
    with tf.Session(graph=target_graph) as active_session:
        tf.import_graph_def(graph_def, name='')
        built = _build_with_feeds_fetches(
            sess=active_session,
            graph=target_graph,
            feed_names=feed_names,
            fetch_names=fetch_names,
        )
    return built
评论列表
文章目录