def average_impurity(self):
"""Constructs a TF graph for evaluating the average leaf impurity of a tree.
If in regression mode, this is the leaf variance. If in classification mode,
this is the gini impurity.
Returns:
The last op in the graph.
"""
children = array_ops.squeeze(array_ops.slice(
self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
is_leaf = math_ops.equal(constants.LEAF_NODE, children)
leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
squeeze_dims=[1]))
counts = array_ops.gather(self.variables.node_sums, leaves)
gini = self._weighted_gini(counts)
# Guard against step 1, when there often are no leaves yet.
def impurity():
return gini
# Since average impurity can be used for loss, when there's no data just
# return a big number so that loss always decreases.
def big():
return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
return control_flow_ops.cond(math_ops.greater(
array_ops.shape(leaves)[0], 0), impurity, big)
评论列表
文章目录