def _bucketize(instances, feature, schema, metadata):
"""Applies the bucketize transform to a numeric field.
"""
field = schema[feature.field]
if not field.numeric:
raise ValueError('A scale transform cannot be applied to non-numerical field "%s".' %
feature.field)
transform = feature.transform
boundaries = map(float, transform['boundaries'].split(','))
# TODO: Figure out how to use tf.case instead of this contrib op
from tensorflow.contrib.layers.python.ops.bucketization_op import bucketize
# Create a one-hot encoded tensor. The dimension of this tensor is the set of buckets defined
# by N boundaries == N + 1.
# A squeeze is needed to remove the extra dimension added to the shape.
value = instances[feature.field]
value = tf.squeeze(tf.one_hot(bucketize(value, boundaries, name='bucket'),
depth=len(boundaries) + 1, on_value=1.0, off_value=0.0,
name='one_hot'),
axis=1, name='bucketize')
value.set_shape((None, len(boundaries) + 1))
return value
评论列表
文章目录