def testCond(self):
    """Tests that compilation handles switch operators.

    Builds ``t = z + cond(c, z, y)`` with ``z = x + 1`` inside a JIT scope,
    then feeds x=2, y=4, c=True. The true branch is taken, so
    ``w = z = 3`` and ``t = 3 + 3 = 6``. The test checks two things:
    an XLA launch is recorded in the run metadata, and the result is 6.
    """
    with self.test_session() as session:
        x = array_ops.placeholder(dtypes.float32)
        y = array_ops.placeholder(dtypes.float32)
        c = array_ops.placeholder(dtypes.bool)
        with jit_scope():
            z = x + 1.0
            w = control_flow_ops.cond(c, lambda: z, lambda: y)
            t = math_ops.add(z, w)

        # If JIT compilation chooses to cluster z and t, then execution will
        # deadlock.
        run_metadata = config_pb2.RunMetadata()
        # FULL_TRACE asks the runtime to fill run_metadata with execution
        # details; presumably MetadataHasXlaLaunch inspects those to find the
        # XLA launch op -- confirm against its definition.
        result = session.run(
            t,
            {x: np.float32(2), y: np.float32(4), c: True},
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
        # assertTrue replaces the deprecated alias assert_, which was removed
        # from unittest in Python 3.12.
        self.assertTrue(MetadataHasXlaLaunch(run_metadata),
                        "expected an XLA launch in the run metadata")
        self.assertAllClose(result, np.float32(6), rtol=1e-1)
jit_test.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录