def testCreatePhasesWithUnwrappedLoop(self):
    """Checks that create_phases raises on a graph built with tf.while_loop.

    The preprocessing function feeds the output of a ``tf.while_loop`` into
    an analyzer (``mappers.scale_to_0_1``). The test expects
    ``impl_helper.create_phases`` to raise ``ValueError`` with a message
    matching ``'Cycle detected'``.
    """
    # The loop below is equivalent to
    #
    #   i = 0
    #   while i < 10:
    #     i += 1
    #     x -= 1
    #
    # i.e. it yields x - 10.
    #
    # An analyzer must consume the loop's output because create_phases only
    # inspects the transitive parents of analyzers; without one the loop
    # would not be inspected at all.
    def preprocessing_fn(inputs):
        def _subtract_ten(x):
            def _cond(i, unused_x):
                return tf.less(i, 10)

            def _body(i, x):
                return tf.add(i, 1), tf.add(x, -1)

            i = tf.constant(0)
            # Index 1 selects x from the loop variables [i, x].
            return tf.while_loop(_cond, _body, [i, x])[1]

        scaled_to_0_1 = mappers.scale_to_0_1(_subtract_ten(inputs['x']))
        return {'x_scaled': scaled_to_0_1}

    input_schema = sch.Schema({
        'x': sch.ColumnSchema(tf.int32, [], sch.FixedColumnRepresentation())
    })
    graph, _, _ = impl_helper.run_preprocessing_fn(
        preprocessing_fn, input_schema)
    with self.assertRaisesRegexp(ValueError, 'Cycle detected'):
        _ = impl_helper.create_phases(graph)
评论列表
文章目录