def testModuleInfo_tuple(self):
    """Tuple inputs and outputs are recorded as tuples in the Sonnet collection.

    Connects a ``DumbModule`` to a 2-tuple of placeholders. It then checks the
    first connected subgraph of the first entry in the Sonnet graph
    collection, once in the freshly built default graph and once more after
    ``_copy_default_graph()`` has been called.
    """
    # pylint: disable=not-callable
    tf.reset_default_graph()
    module = DumbModule(name="dumb_a")
    placeholder_shape = (1, 10)
    first_input = tf.placeholder(dtype=tf.float32, shape=placeholder_shape)
    second_input = tf.placeholder(dtype=tf.float32, shape=placeholder_shape)
    module((first_input, second_input))

    def assert_subgraph_holds_tuples():
        """Assert the first recorded subgraph kept tuple inputs and outputs."""
        # Look the default graph up on every call: it is fetched again after
        # _copy_default_graph() rather than reusing the earlier graph object.
        graph = tf.get_default_graph()
        collection = graph.get_collection(base_info.SONNET_COLLECTION_NAME)
        subgraph = collection[0].connected_subgraphs[0]
        self.assertIsInstance(subgraph.inputs["inputs"], tuple)
        self.assertIsInstance(subgraph.outputs, tuple)

    # First in the graph where the module was connected...
    assert_subgraph_holds_tuples()
    # ...then again once _copy_default_graph() has run.
    _copy_default_graph()
    assert_subgraph_holds_tuples()
# (scraped web-page residue) 评论列表 — "Comment list"
# (scraped web-page residue) 文章目录 — "Table of contents"