def get_stats(self, session):
    """Return node and leaf counts for this tree.

    Both values are read from the variables in a single ``session.run``
    call, so they come from the same step. Leaves are counted on the host
    from the fetched tree values. This avoids adding new ops to the graph on
    every call, which would grow the graph without bound. It also keeps this
    method working when the graph has been finalized or is not the default
    graph.

    Args:
      session: The session used to read ``self.variables.end_of_tree`` and
        ``self.variables.tree``.

    Returns:
      ``TreeStats(num_nodes, num_leaves)``. ``num_nodes`` is the fetched
      ``end_of_tree`` value minus one, exactly as before. ``num_leaves`` is the
      number of rows of ``tree`` whose first column equals
      ``constants.LEAF_NODE``, returned as a Python ``int``.
    """
    end_of_tree, tree_values = session.run(
        [self.variables.end_of_tree, self.variables.tree])
    num_nodes = end_of_tree - 1
    # Column 0 of each row is compared against LEAF_NODE, matching the
    # original slice(tree, [0, 0], [-1, 1]) == LEAF_NODE test.
    # NOTE(review): assumes `tree` is at least 2-D (rows x columns), as the
    # original 2-element slice begin/size implied.
    num_leaves = int((tree_values[:, 0] == constants.LEAF_NODE).sum())
    return TreeStats(num_nodes, num_leaves)
评论列表
文章目录