def _buckets(data, bucket_count=None):
    """Create a TensorFlow op to group data into histogram buckets.

    The data is flattened and cast to `float64`, then split into
    `bucket_count` equal-width buckets spanning `[min(data), max(data)]`.

    Arguments:
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
        When `None`, the module-level `DEFAULT_BUCKET_COUNT` is used.

    Returns:
      A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
      a triple `[left_edge, right_edge, count]` for a single bucket.
      The value of `k` is:
        * `0` if `data` has no elements;
        * `1` if all elements of `data` are equal, in which case the
          single bucket is `[value - 0.5, value + 0.5]` and holds every
          element;
        * `bucket_count` otherwise.
    """
    if bucket_count is None:
        bucket_count = DEFAULT_BUCKET_COUNT
    with tf.name_scope('buckets', values=[data, bucket_count]), \
            tf.control_dependencies([tf.assert_scalar(bucket_count),
                                     tf.assert_type(bucket_count, tf.int32)]):
        data = tf.reshape(data, shape=[-1])  # flatten
        data = tf.cast(data, tf.float64)
        is_empty = tf.equal(tf.size(data), 0)

        def when_empty():
            # No data at all: return zero buckets, keeping the 3-column layout.
            return tf.constant([], shape=(0, 3), dtype=tf.float64)

        def when_nonempty():
            min_ = tf.reduce_min(data)
            max_ = tf.reduce_max(data)
            range_ = max_ - min_
            # A zero range would make the bucket width zero and the index
            # computation below a division by zero, so handle it separately.
            is_singular = tf.equal(range_, 0)

            def when_nonsingular():
                bucket_width = range_ / tf.cast(bucket_count, tf.float64)
                offsets = data - min_
                bucket_indices = tf.cast(tf.floor(offsets / bucket_width),
                                         dtype=tf.int32)
                # Elements equal to the maximum land on index `bucket_count`;
                # clamp them into the last bucket so its right edge is
                # inclusive.
                clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
                # Build the one-hot matrix in float64 rather than the default
                # float32: summing more than 2**24 individual `1.0` values in
                # float32 silently loses counts. The sum is therefore already
                # float64 and needs no further cast.
                one_hots = tf.one_hot(clamped_indices, depth=bucket_count,
                                      dtype=tf.float64)
                bucket_counts = tf.reduce_sum(one_hots, axis=0)
                edges = tf.lin_space(min_, max_, bucket_count + 1)
                left_edges = edges[:-1]
                right_edges = edges[1:]
                # Stack as rows, then transpose to one row per bucket.
                return tf.transpose(tf.stack(
                    [left_edges, right_edges, bucket_counts]))

            def when_singular():
                # All values identical: emit one unit-width bucket centered
                # on that value, containing every element.
                center = min_
                bucket_starts = tf.stack([center - 0.5])
                bucket_ends = tf.stack([center + 0.5])
                bucket_counts = tf.stack([tf.cast(tf.size(data), tf.float64)])
                return tf.transpose(
                    tf.stack([bucket_starts, bucket_ends, bucket_counts]))

            return tf.cond(is_singular, when_singular, when_nonsingular)

        return tf.cond(is_empty, when_empty, when_nonempty)
评论列表
文章目录