_transforms.py source code

Project: tensorfx  Author: TensorLab

```python
import tensorflow as tf


def _bucketize(instances, feature, schema, metadata):
  """Applies the bucketize transform to a numeric field.

  Produces a one-hot encoded tensor with one column per bucket.
  """
  field = schema[feature.field]
  if not field.numeric:
    raise ValueError('A bucketize transform cannot be applied to non-numerical field "%s".' %
                     feature.field)

  transform = feature.transform
  # Materialize as a list: in Python 3, map() returns an iterator, which has no len().
  boundaries = list(map(float, transform['boundaries'].split(',')))

  # TODO: Figure out how to use tf.case instead of this contrib op
  # (tf.contrib is only available in TensorFlow 1.x).
  from tensorflow.contrib.layers.python.ops.bucketization_op import bucketize

  # Create a one-hot encoded tensor. Its depth is the number of buckets:
  # N boundaries define N + 1 buckets.
  # The squeeze removes the extra dimension that one_hot adds, assuming the
  # input column has shape (batch, 1).
  value = instances[feature.field]

  value = tf.squeeze(tf.one_hot(bucketize(value, boundaries, name='bucket'),
                                depth=len(boundaries) + 1, on_value=1.0, off_value=0.0,
                                name='one_hot'),
                     axis=1, name='bucketize')
  value.set_shape((None, len(boundaries) + 1))
  return value
```
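To make the bucket arithmetic concrete without needing TensorFlow 1.x, here is a minimal pure-Python sketch of the same idea: parse a comma-separated boundary string, assign each value a bucket id, and one-hot encode it into `len(boundaries) + 1` columns. It assumes the boundaries are sorted and that a value equal to a boundary falls into the bucket on its right, which is our understanding of how TensorFlow's bucketize op behaves. The helper names (`parse_boundaries`, `bucket_index`, `bucketize_one_hot`) are hypothetical and not part of tensorfx; the input is a flat list rather than a `(batch, 1)` tensor.

```python
from bisect import bisect_right
from typing import List, Sequence


def parse_boundaries(spec: str) -> List[float]:
    """Hypothetical helper: parse a boundary string such as "0,10,100"."""
    # A list (not a bare map object) so that len() works in Python 3.
    return [float(b) for b in spec.split(',')]


def bucket_index(x: float, boundaries: Sequence[float]) -> int:
    """Hypothetical helper: return the bucket id for x (assumes sorted boundaries).

    bisect_right counts the boundaries <= x, so a value equal to a boundary
    lands in the bucket to its right, and N boundaries give ids 0..N.
    """
    return bisect_right(boundaries, x)


def bucketize_one_hot(values: Sequence[float],
                      boundaries: Sequence[float]) -> List[List[float]]:
    """Hypothetical helper: one-hot encode each value's bucket.

    Every row has len(boundaries) + 1 entries, mirroring the
    (None, len(boundaries) + 1) shape set in _bucketize.
    """
    depth = len(boundaries) + 1
    rows = []
    for x in values:
        row = [0.0] * depth           # off_value for every bucket
        row[bucket_index(x, boundaries)] = 1.0  # on_value for the matching bucket
        rows.append(row)
    return rows


if __name__ == '__main__':
    boundaries = parse_boundaries('0,10,100')
    values = [-5.0, 5.0, 10.0, 150.0]
    for x, row in zip(values, bucketize_one_hot(values, boundaries)):
        # e.g. 10.0 [0.0, 0.0, 1.0, 0.0]
        print(x, row)
```

With three boundaries there are four buckets: below 0, [0, 10), [10, 100), and 100 or above, so each output row has exactly one 1.0 among four columns.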