def testModuleInfo_namedtuple(self):
    """Namedtuple inputs/outputs are recorded as namedtuples in module info.

    A fresh default graph is created, a `DumbModule` is connected to a
    `DumbNamedTuple` of two float32 placeholders, and the first connected
    subgraph of the first entry in the Sonnet graph collection is inspected.
    Both its recorded `inputs["inputs"]` and its `outputs` must satisfy
    `base_info._is_namedtuple`. The check runs once on the original graph
    and once more after `_copy_default_graph()` (a helper defined elsewhere;
    presumably it round-trips the default graph — confirm against its
    definition).
    """
    # pylint: disable=not-callable
    tf.reset_default_graph()
    module = DumbModule(name="dumb_a")
    # Two identical placeholders, created in the same order as before so the
    # auto-generated op names are unchanged.
    first_input, second_input = [
        tf.placeholder(dtype=tf.float32, shape=(1, 10)) for _ in range(2)
    ]
    module(DumbNamedTuple(first_input, second_input))

    def assert_namedtuple_preserved():
        # Re-fetch the default graph on every call so the second check sees
        # whatever graph is current after _copy_default_graph().
        graph = tf.get_default_graph()
        infos = graph.get_collection(base_info.SONNET_COLLECTION_NAME)
        subgraph = infos[0].connected_subgraphs[0]
        recorded_inputs = subgraph.inputs["inputs"]
        self.assertTrue(base_info._is_namedtuple(recorded_inputs))
        self.assertTrue(base_info._is_namedtuple(subgraph.outputs))

    assert_namedtuple_preserved()
    _copy_default_graph()
    assert_namedtuple_preserved()
评论列表
文章目录