base_info_test.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testModuleInfo_sparsetensor(self):
    # pylint: disable=not-callable
    tf.reset_default_graph()
    dumb = DumbModule(name="dumb_a")
    sparse_tensor = tf.SparseTensor(
        indices=tf.placeholder(dtype=tf.int64, shape=(10, 2,)),
        values=tf.placeholder(dtype=tf.float32, shape=(10,)),
        dense_shape=tf.placeholder(dtype=tf.int64, shape=(2,)))
    dumb(sparse_tensor)
    def check():
      sonnet_collection = tf.get_default_graph().get_collection(
          base_info.SONNET_COLLECTION_NAME)
      connected_subgraph = sonnet_collection[0].connected_subgraphs[0]
      self.assertIsInstance(
          connected_subgraph.inputs["inputs"], tf.SparseTensor)
      self.assertIsInstance(connected_subgraph.outputs, tf.SparseTensor)
    check()
    _copy_default_graph()
    check()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号