def testCond(self):
"""Tests that tf.cond works on XLA devices."""
with session_lib.Session() as session:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
c = array_ops.placeholder(dtypes.bool)
with ops.device("device:XLA_CPU:0"):
z = x + 1.0
w = control_flow_ops.cond(c, lambda: z, lambda: y)
t = math_ops.add(z, w)
result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True})
self.assertAllClose(result, np.float32(6), rtol=1e-3)
xla_device_test.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录