def _gini(self, class_counts):
    """Calculate the add-one-smoothed Gini impurity of each row.

    Column 0 of ``class_counts`` is dropped before scoring. Presumably it
    holds a per-row total rather than a class count; TODO confirm against
    the layout of variables.node_sums. Every remaining count is then
    incremented by 1 (add-one / Laplace smoothing). If c(i) denotes the
    i-th remaining class count and there are K such columns:

      s(i)  = c(i) + 1
      s     = sum_i s(i)
      score = 1 - sum_i ( s(i) / s )^2

    Because of the smoothing, a row whose counts are all zero does not
    divide by zero. It scores 1 - 1/K, the maximum impurity for K classes.

    NOTE(review): if ``class_counts`` has only one column, nothing is left
    after the slice. The result is then 0 / 0 for every row (presumably
    NaN for float tensors).

    Args:
      class_counts: A 2-D tensor of counts, usually a slice or gather from
        variables.node_sums. Column 0 is ignored; columns 1.. are treated
        as per-class counts. Presumably a floating-point tensor, since the
        float literal 1.0 is added to it; verify against callers.

    Returns:
      A 1-D tensor of the Gini impurities for each row in the input.
    """
    # Skip column 0 (size -1 keeps all remaining rows/columns), then
    # add-one smooth every remaining per-class count: s(i) = c(i) + 1.
    smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
    # Per-row total of the smoothed counts: s.
    sums = math_ops.reduce_sum(smoothed, 1)
    # Per-row sum of squared smoothed counts: sum_i s(i)^2.
    sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)
    # sum_i (s(i) / s)^2 == sum_i s(i)^2 / s^2, so a single division suffices.
    return 1.0 - sum_squares / (sums * sums)
评论列表
文章目录