def testJITVariableSeed(self):
    """Seeded variable initializers match with and without JIT enabled.

    XLA does not currently support seeded initialization, so XLA initializers
    return different values than their non-XLA counterparts. This test
    verifies that JIT compilation can be disabled for the stateful
    initializer, so that the resulting variable values are the same as when
    no JIT compilation happens at all.
    """

    def create_ops():
        # A fixed seed keeps the initial value comparable across runs.
        seeded_uniform = init_ops.random_uniform_initializer(
            -0.1, 0.1, seed=2)
        with variable_scope.variable_scope(
                "root", initializer=seeded_uniform):
            return variable_scope.get_variable("var", (1,))

    # self.compute returns a pair; only its second element is compared.
    # The call order is kept as-is: two runs with the flag set to False,
    # then two runs with enable_jit_nonstateful.
    _, plain_first = self.compute(False, create_ops)
    _, plain_second = self.compute(False, create_ops)
    _, jit_first = self.compute(enable_jit_nonstateful, create_ops)
    _, jit_second = self.compute(enable_jit_nonstateful, create_ops)

    # Each configuration must be reproducible on its own ...
    self.assertAllClose(plain_first, plain_second)
    self.assertAllClose(jit_first, jit_second)
    # ... and both configurations must agree with each other.
    self.assertAllClose(plain_first, jit_first)
jit_test.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录