def testCreatePhasesWithLoop(self):
    """Checks create_phases on a preprocessing_fn that contains a tf.while_loop.

    The loop is wrapped in api.apply_function and its output is passed to
    mappers.scale_to_0_1. The test asserts that create_phases yields exactly
    one phase holding two analyzers.
    """
    # Test a preprocessing function with control flow.
    #
    # The loop represents
    #
    #   i = 0
    #   while i < 10:
    #     i += 1
    #     x -= 1
    #
    # i.e. x is decremented once per iteration (ten times in total), which is
    # why the helper below is called _subtract_ten.
    #
    # To get an error in the case where apply_function is not called, we have
    # to call an analyzer first (see testCreatePhasesWithUnwrappedLoop). So
    # we also do so here.
    def preprocessing_fn(inputs):
        def _subtract_ten(x):
            # tf.while_loop calls the condition and body positionally with the
            # current loop variables (i, x).
            def _cond(i, unused_x):
                return tf.less(i, 10)

            def _body(i, x):
                return tf.add(i, 1), tf.add(x, -1)

            i = tf.constant(0)
            # Index 1 selects the final value of x from the loop variables.
            return tf.while_loop(_cond, _body, [i, x])[1]

        scaled_to_0_1 = mappers.scale_to_0_1(
            api.apply_function(_subtract_ten, inputs['x']))
        return {'x_scaled': scaled_to_0_1}

    input_schema = sch.Schema({
        'x': sch.ColumnSchema(tf.int32, [], sch.FixedColumnRepresentation())
    })
    graph, _, _ = impl_helper.run_preprocessing_fn(
        preprocessing_fn, input_schema)
    phases = impl_helper.create_phases(graph)
    # Expect a single phase. scale_to_0_1 presumably contributes both analyzers
    # (likely a min and a max) — NOTE(review): confirm against mappers.
    self.assertEqual(len(phases), 1)
    self.assertEqual(len(phases[0].analyzers), 2)
评论列表
文章目录