def testRunPreprocessingFn(self):
    """Check dtypes and static shapes of tensors from run_preprocessing_fn.

    Builds a schema with two fixed-length features, one variable-length
    feature and one sparse feature, runs a small preprocessing function
    through `impl_helper.run_preprocessing_fn`, and then compares every
    returned input and output tensor against an expected (dtype, shape)
    pair.
    """
    feature_spec = {
        'dense_1': tf.FixedLenFeature((), tf.float32),
        'dense_2': tf.FixedLenFeature((1, 2), tf.int64),
        'var_len': tf.VarLenFeature(tf.string),
        'sparse': tf.SparseFeature('ix', 'val', tf.float32, 100)
    }
    schema = self.toSchema(feature_spec)

    def preprocessing_fn(inputs):
        # One dense output and one sparse output.
        scaled = mappers.scale_to_0_1(inputs['dense_1'])
        reshaped = tf.sparse_reshape(inputs['sparse'], (1, 10))
        return {'dense_out': scaled, 'sparse_out': reshaped}

    results = impl_helper.run_preprocessing_fn(preprocessing_fn, schema)
    _, input_tensors, output_tensors = results

    # Expected (dtype, shape) per tensor name; every expected shape has an
    # unknown (None) leading dimension.
    rank_1 = tf.TensorShape([None])
    rank_2 = tf.TensorShape([None, None])
    expected = {
        # Inputs.
        'dense_1': (tf.float32, rank_1),
        'dense_2': (tf.int64, tf.TensorShape([None, 1, 2])),
        'var_len': (tf.string, rank_2),
        'sparse': (tf.float32, rank_2),
        # Outputs.
        'dense_out': (tf.float32, rank_1),
        'sparse_out': (tf.float32, rank_2),
    }

    # Inputs are checked first, then outputs.
    for tensors in (input_tensors, output_tensors):
        for name, tensor in six.iteritems(tensors):
            want_dtype, want_shape = expected[name]
            self.assertEqual(tensor.dtype, want_dtype)
            tensor.get_shape().assert_is_compatible_with(want_shape)
评论列表
文章目录