summary.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:tensorboard 作者: tensorflow 项目源码 文件源码
def _buckets(data, bucket_count=None):
  """Create a TensorFlow op to group data into histogram buckets.

  Bucket edges are evenly spaced between the minimum and the maximum of
  `data`; the maximum value itself is counted in the last bucket.

  Arguments:
    data: A `Tensor` of any shape. Must be castable to `float64`.
    bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
      When `None`, `DEFAULT_BUCKET_COUNT` is used.
  Returns:
    A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
    a triple `[left_edge, right_edge, count]` for a single bucket.
    The value of `k` is `0` when `data` is empty, `1` when every element
    of `data` is equal (a unit-width bucket centered on that value), and
    `bucket_count` otherwise.
  """
  if bucket_count is None:
    bucket_count = DEFAULT_BUCKET_COUNT
  # Every op built inside this scope gets a control dependency on the
  # `bucket_count` shape/type checks.
  with tf.name_scope('buckets', values=[data, bucket_count]), \
       tf.control_dependencies([tf.assert_scalar(bucket_count),
                                tf.assert_type(bucket_count, tf.int32)]):
    data = tf.reshape(data, shape=[-1])  # flatten
    data = tf.cast(data, tf.float64)
    is_empty = tf.equal(tf.size(data), 0)

    def when_empty():
      """No elements: return an empty `[0, 3]` bucket table."""
      return tf.constant([], shape=(0, 3), dtype=tf.float64)

    def when_nonempty():
      """At least one element: bucket the values over `[min, max]`."""
      min_ = tf.reduce_min(data)
      max_ = tf.reduce_max(data)
      range_ = max_ - min_
      is_singular = tf.equal(range_, 0)

      def when_nonsingular():
        """`min != max`: split the range into `bucket_count` equal buckets."""
        # NOTE(review): if `data` contains NaN or +/-inf, `range_` and
        # `bucket_width` become non-finite and the int32 cast below is not
        # meaningful — confirm callers filter such values.
        bucket_width = range_ / tf.cast(bucket_count, tf.float64)
        offsets = data - min_
        bucket_indices = tf.cast(tf.floor(offsets / bucket_width),
                                 dtype=tf.int32)
        # The maximum value can map to index `bucket_count` (one past the
        # end); fold it into the last bucket so the right edge is inclusive.
        clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
        # Build the one-hot matrix (and hence the sum) in float64: float32
        # cannot represent integer counts above 2**24 exactly, so summing
        # float32 ones would give inaccurate counts for very large buckets.
        one_hots = tf.one_hot(clamped_indices, depth=bucket_count,
                              dtype=tf.float64)
        bucket_counts = tf.reduce_sum(one_hots, axis=0)
        edges = tf.lin_space(min_, max_, bucket_count + 1)
        left_edges = edges[:-1]
        right_edges = edges[1:]
        return tf.transpose(tf.stack(
            [left_edges, right_edges, bucket_counts]))

      def when_singular():
        """`min == max`: one unit-width bucket centered on that value."""
        center = min_
        bucket_starts = tf.stack([center - 0.5])
        bucket_ends = tf.stack([center + 0.5])
        bucket_counts = tf.stack([tf.cast(tf.size(data), tf.float64)])
        return tf.transpose(
            tf.stack([bucket_starts, bucket_ends, bucket_counts]))

      return tf.cond(is_singular, when_singular, when_nonsingular)

    return tf.cond(is_empty, when_empty, when_nonempty)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号