jit_test.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testCond(self):
    """Tests that compilation handles switch operators.

    Builds ``t = z + cond(c, z, y)`` with ``z = x + 1.0`` inside a JIT
    scope, runs it with full tracing enabled, and checks both the run
    metadata and the numeric result.
    """

    with self.test_session() as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with jit_scope():
        z = x + 1.0
        # The conditional selects between the in-scope value z and the
        # externally fed y.
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      # If JIT compilation chooses to cluster z and t, then execution will
      # deadlock.

      run_metadata = config_pb2.RunMetadata()
      result = session.run(t, {x: np.float32(2),
                               y: np.float32(4),
                               c: True},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      # assertTrue replaces the deprecated assert_ alias, which was removed
      # from unittest.TestCase in Python 3.12.
      # NOTE(review): MetadataHasXlaLaunch is defined outside this block;
      # presumably it reports whether the trace contains an XLA launch op.
      self.assertTrue(MetadataHasXlaLaunch(run_metadata))
      # With x = 2 and c = True: z = 3, w = z = 3, so t = 6.
      self.assertAllClose(result, np.float32(6), rtol=1e-1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号