def testModuleInfo_multiple_subgraph(self):
  """Connecting one module twice records two subgraphs in the collection.

  A single `DumbModule` named "dumb_a" is connected to the same placeholder
  twice: once at the top level and once inside the "foo" name scope. The
  Sonnet collection of the default graph must then hold exactly one entry
  whose two connected subgraphs carry the name scopes "dumb_a" and
  "foo/dumb_a". The same checks are repeated after `_copy_default_graph()`.
  NOTE(review): that helper is defined elsewhere; presumably it copies or
  round-trips the default graph, so the second check verifies the info
  survives the copy. Confirm against its definition.
  """
  # pylint: disable=not-callable
  tf.reset_default_graph()
  module = DumbModule(name="dumb_a")
  inputs = tf.placeholder(dtype=tf.float32, shape=(1, 10,))

  # First connection: top-level scope.
  module(inputs)
  # Second connection: same module, nested under the "foo" name scope.
  with tf.name_scope("foo"):
    module(inputs)

  def assert_two_subgraphs_recorded():
    # Look the graph up on every call: `_copy_default_graph()` may swap the
    # default graph between the two invocations below.
    graph = tf.get_default_graph()
    entries = graph.get_collection(base_info.SONNET_COLLECTION_NAME)
    self.assertEqual(len(entries), 1)
    subgraphs = entries[0].connected_subgraphs
    self.assertEqual(len(subgraphs), 2)
    self.assertEqual(subgraphs[0].name_scope, "dumb_a")
    self.assertEqual(subgraphs[1].name_scope, "foo/dumb_a")

  assert_two_subgraphs_recorded()
  _copy_default_graph()
  assert_two_subgraphs_recorded()
评论列表
文章目录