input.py source code (Python)

Project: spark-deep-learning | Author: databricks

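import tensorflow as tf  # TensorFlow 1.x API (tf.GraphDef, tf.Session)


class TFInputGraph(object):
    # Excerpt: only the fromGraphDef constructor is shown here.

    @classmethod
    def fromGraphDef(cls, graph_def, feed_names, fetch_names):
        """
        Construct a TFInputGraph from a tf.GraphDef object.

        :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and
                           computation units of the TensorFlow graph.
        :param feed_names: list, names of the input tensors.
        :param fetch_names: list, names of the output tensors.
        """
        assert isinstance(graph_def, tf.GraphDef), \
            'expected tf.GraphDef but got {}'.format(type(graph_def))

        graph = tf.Graph()
        # Entering the session as a context manager also makes `graph` the default graph.
        with tf.Session(graph=graph) as sess:
            # name='' imports the nodes without the default "import/" prefix,
            # so feed_names and fetch_names resolve against the original names.
            tf.import_graph_def(graph_def, name='')
            # _build_with_feeds_fetches is a module-level helper in input.py (not shown here).
            return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
                                             fetch_names=fetch_names)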
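To illustrate what the import step does, here is a minimal standalone sketch, assuming TensorFlow is installed and using the `tf.compat.v1` graph APIs so it also runs under TensorFlow 2.x. It is not sparkdl's `TFInputGraph`: the helpers `make_graph_def` and `run_graph_def` are hypothetical names for illustration. It serializes a tiny `y = 2 * x` graph to a `GraphDef`, imports it into a fresh graph with `name=''` as `fromGraphDef` does, and then resolves the feed and fetch tensors by name (`"x:0"`, `"y:0"`).

```python
import numpy as np
import tensorflow as tf

# TF 1.x-style graph APIs; under TensorFlow 2.x they live in tf.compat.v1.
tf1 = tf.compat.v1


def make_graph_def():
    """Hypothetical helper: build y = 2 * x and return its tf.GraphDef."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None], name="x")
        tf.multiply(x, 2.0, name="y")
    return g.as_graph_def()


def run_graph_def(graph_def, feed_names, fetch_names, feed_values):
    """Hypothetical helper: import graph_def into a fresh graph, run it by tensor name.

    name="" keeps the original node names (no "import/" prefix),
    so "x:0" still refers to the placeholder defined above.
    """
    assert isinstance(graph_def, tf1.GraphDef), type(graph_def)
    graph = tf.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name="")
    feeds = [graph.get_tensor_by_name(n) for n in feed_names]
    fetches = [graph.get_tensor_by_name(n) for n in fetch_names]
    with tf1.Session(graph=graph) as sess:
        return sess.run(fetches, feed_dict=dict(zip(feeds, feed_values)))


if __name__ == "__main__":
    gdef = make_graph_def()
    (y,) = run_graph_def(gdef, ["x:0"], ["y:0"],
                         [np.array([1.0, 2.0, 3.0], dtype=np.float32)])
    print(y)  # [2. 4. 6.]
```

If the import used the default name scope instead of `name=""`, the same tensors would be reachable only as `"import/x:0"` and `"import/y:0"`, which is why passing an empty name keeps user-supplied feed and fetch names valid.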