def testModuleInfo_sparsetensor(self):
  """Module info records SparseTensor inputs and outputs.

  Builds a `tf.SparseTensor` from three placeholders and connects a
  `DumbModule` to it. It then asserts that the first connected subgraph of the
  first entry in the `base_info.SONNET_COLLECTION_NAME` graph collection
  stores:
  - the input under the "inputs" key as a `tf.SparseTensor`;
  - its output as a `tf.SparseTensor`.
  The same assertions run again after `_copy_default_graph()`.
  """
  # pylint: disable=not-callable
  tf.reset_default_graph()
  module = DumbModule(name="dumb_a")
  # Placeholders are created in the same order as the original SparseTensor
  # constructor arguments (indices, values, dense_shape), so op naming
  # is unchanged.
  indices_ph = tf.placeholder(dtype=tf.int64, shape=(10, 2))
  values_ph = tf.placeholder(dtype=tf.float32, shape=(10,))
  dense_shape_ph = tf.placeholder(dtype=tf.int64, shape=(2,))
  module(tf.SparseTensor(
      indices=indices_ph, values=values_ph, dense_shape=dense_shape_ph))

  def assert_sparse_io_recorded():
    # The default graph is looked up on every call, so the second call
    # inspects whatever graph is default after `_copy_default_graph()`.
    graph = tf.get_default_graph()
    entries = graph.get_collection(base_info.SONNET_COLLECTION_NAME)
    subgraph = entries[0].connected_subgraphs[0]
    self.assertIsInstance(subgraph.inputs["inputs"], tf.SparseTensor)
    self.assertIsInstance(subgraph.outputs, tf.SparseTensor)

  assert_sparse_io_recorded()
  # NOTE(review): presumably `_copy_default_graph` rebuilds the default graph
  # (e.g. via a serialize/deserialize round trip); confirm at its definition.
  _copy_default_graph()
  assert_sparse_io_recorded()
评论列表
文章目录