base_info_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testModuleInfo_multiple_subgraph(self):
    # pylint: disable=not-callable
    tf.reset_default_graph()
    dumb = DumbModule(name="dumb_a")
    ph_0 = tf.placeholder(dtype=tf.float32, shape=(1, 10,))
    dumb(ph_0)
    with tf.name_scope("foo"):
      dumb(ph_0)
    def check():
      sonnet_collection = tf.get_default_graph().get_collection(
          base_info.SONNET_COLLECTION_NAME)
      self.assertEqual(len(sonnet_collection), 1)
      self.assertEqual(len(sonnet_collection[0].connected_subgraphs), 2)
      connected_subgraph_0 = sonnet_collection[0].connected_subgraphs[0]
      connected_subgraph_1 = sonnet_collection[0].connected_subgraphs[1]
      self.assertEqual(connected_subgraph_0.name_scope, "dumb_a")
      self.assertEqual(connected_subgraph_1.name_scope, "foo/dumb_a")
    check()
    _copy_default_graph()
    check()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号