def _weighted_gini(self, class_counts):
    """Return the count-weighted Gini impurity for each row of class counts.

    With c(i) the (smoothed) count of class i and c = sum_i c(i), the score
    for a row is the Gini impurity multiplied by the number of examples:

        score = c * (1 - sum_i (c(i) / c)^2)
              = c - sum_i c(i)^2 / c

    Column 0 of `class_counts` is excluded from the computation, and 1.0 is
    added to every remaining count before scoring, so `c` above is the sum
    of the smoothed counts rather than of the raw ones.
    NOTE(review): column 0 presumably holds a per-row total rather than a
    class count -- confirm against the layout of variables.node_sums.

    Args:
      class_counts: A 2-D tensor of per-class counts, usually a slice or
        gather from variables.node_sums.

    Returns:
      A 1-D tensor holding one weighted Gini score per input row.
    """
    # Keep every row, but only the columns from index 1 onward.
    per_class = array_ops.slice(class_counts, [0, 1], [-1, -1])
    # Add-one smoothing of the per-class counts.
    smoothed_counts = 1.0 + per_class

    # Row-wise c and sum_i c(i)^2.
    row_totals = math_ops.reduce_sum(smoothed_counts, 1)
    squared_counts = math_ops.square(smoothed_counts)
    row_squared_totals = math_ops.reduce_sum(squared_counts, 1)

    # c - sum_i c(i)^2 / c
    return row_totals - row_squared_totals / row_totals
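# NOTE(review): the two lines originally here ("Comment list" and
# "Table of contents") look like page-navigation residue, not code.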