xla_device_test.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testCond(self):
    """Tests that tf.cond works on XLA devices."""

    with session_lib.Session() as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with ops.device("device:XLA_CPU:0"):
        z = x + 1.0
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True})
      self.assertAllClose(result, np.float32(6), rtol=1e-3)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号