base_info_test.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testModuleInfo_recursion(self):
    # pylint: disable=not-callable
    tf.reset_default_graph()
    dumb = DumbModule(name="dumb_a", no_nest=True)
    ph_0 = tf.placeholder(dtype=tf.float32, shape=(1, 10,))
    val = {"one": ph_0, "self": None}
    val["self"] = val
    dumb(val)
    def check(check_type):
      sonnet_collection = tf.get_default_graph().get_collection(
          base_info.SONNET_COLLECTION_NAME)
      connected_subgraph = sonnet_collection[0].connected_subgraphs[0]
      self.assertIsInstance(connected_subgraph.inputs["inputs"]["one"],
                            tf.Tensor)
      self.assertIsInstance(
          connected_subgraph.inputs["inputs"]["self"], check_type)
      self.assertIsInstance(connected_subgraph.outputs["one"], tf.Tensor)
      self.assertIsInstance(connected_subgraph.outputs["self"], check_type)
    check(dict)
    _copy_default_graph()
    check(base_info._UnserializableObject)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号