env_cache_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:imperative 作者: yaroslavvb 项目源码 文件源码
def testSum1CacheGpu(self):
    if not tf.test.is_built_with_cuda():
      return True
    if not self.haveGpu0():
      return True

    env = imperative.Env(tf)
    with env.g.device("cpu:0"):
      val1 = env.numpy_to_itensor([1, 2, 3])
      val2 = env.numpy_to_itensor([4, 5, 6])
      val3 = env.numpy_to_itensor([4, 5, 6], dtype=tf.float64)
      try:
        out1 = env.sum1(val1)
      except:
        import pdb;
        pdb.post_mortem()
      self.assertTrue(is_graph_changed(env))
      out2 = env.sum1(val2)
      self.assertFalse(is_graph_changed(env))
      out3 = env.sum1(val3)
      self.assertTrue(is_graph_changed(env))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号