def _recurse(tree, feature_vec):
    """Follow ``feature_vec`` down a fitted scikit-learn ``Tree`` to a leaf.

    Starting at the root (node 0), each internal node sends the sample to
    its left child when ``float32(feature_vec[feature]) <= threshold`` and to
    its right child otherwise. The walk stops at a node whose split feature
    is ``TREE_UNDEFINED`` (a leaf) or when the node id becomes ``TREE_LEAF``.

    Args:
        tree: A scikit-learn ``Tree`` object (e.g. ``estimator.tree_``).
        feature_vec: Feature values for one sample, indexable by the tree's
            feature indices (list, tuple, 1-D numpy array, ...). A bare
            scalar is wrapped in a one-element list, so a single-feature
            tree can be queried with a plain number.

    Returns:
        ``(leaf_node_id, lower, upper)``:

        - ``leaf_node_id``: the id of the node the sample ends in (0 if the
          root is already a leaf).
        - ``lower``: the threshold of the last split where the sample went
          right (``-inf`` if it never did).
        - ``upper``: the threshold of the last split where it went left
          (``+inf`` if it never did).

        Both bounds are overwritten regardless of which feature a split
        tests. They therefore describe an interval on a single feature only
        when every split on the path uses that same feature.

    Raises:
        AssertionError: If ``tree`` is not a ``Tree``.
        RuntimeError: If the walk takes more than ``2 * tree.node_count``
            steps, which can only happen for a malformed (cyclic) tree.
    """
    assert isinstance(tree, Tree), "Tree is not a sklearn Tree"

    # Wrap only true scalars. Sequences and 1-D arrays are indexed directly;
    # wrapping them would make feature_vec[feature_idx] return the whole row.
    if not isinstance(feature_vec, list) and np.ndim(feature_vec) == 0:
        feature_vec = [feature_vec]

    node_id = 0
    leaf_node_id = 0
    # -np.inf / np.inf instead of np.NINF / np.PINF, removed in NumPy 2.0.
    lower = -np.inf
    upper = np.inf
    steps = 0

    # Short-circuit `and` so tree.feature is never indexed with TREE_LEAF.
    while node_id != TREE_LEAF and tree.feature[node_id] != TREE_UNDEFINED:
        feature_idx = tree.feature[node_id]
        threshold = tree.threshold[node_id]
        # Cast to float32: scikit-learn converts inputs to float32 before
        # comparing them against the stored split thresholds.
        if np.float32(feature_vec[feature_idx]) <= threshold:
            upper = threshold
            child = tree.children_left[node_id]
        else:
            lower = threshold
            child = tree.children_right[node_id]

        # The same validity test applies to both branches. The right branch
        # previously tested `== TREE_LEAF`, so right turns were never recorded.
        if child != TREE_LEAF and child != TREE_UNDEFINED:
            leaf_node_id = child
        node_id = child

        # Guard against cycles in a corrupted tree structure.
        steps += 1
        if steps > 2 * tree.node_count:
            raise RuntimeError("infinite recursion!")

    return leaf_node_id, lower, upper
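# Web-page residue from where this snippet was copied; not part of the code:
# "评论列表" ("comment list"), "文章目录" ("table of contents").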