gam.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:dstk 作者: jotterbach 项目源码 文件源码
def _recurse(tree, feature_vec):
    """Walk a decision tree for one sample and report where it ends up.

    Starting at the root (node 0), each visited node's split is applied:
    the walk goes left when the sample's value for the node's split feature
    is ``<=`` the node's threshold and right otherwise. It stops as soon as
    the current node has no split feature (``TREE_UNDEFINED``) or the node
    id becomes ``TREE_LEAF``.

    Args:
        tree: The tree to traverse; must be an instance of ``Tree``.
        feature_vec: Feature values of the sample, indexed by feature id.
            Anything that is not a ``list`` is wrapped in a one-element list.

    Returns:
        A tuple ``(leaf_node_id, lower, upper)``: the id of the last valid
        node reached, the threshold of the most recent right turn (``-inf``
        if the walk never went right) and the threshold of the most recent
        left turn (``+inf`` if the walk never went left).

    Raises:
        AssertionError: If ``tree`` is not a ``Tree``.
        RuntimeError: If the walk takes more than ``2 * tree.node_count``
            steps, which guards against a malformed (cyclic) tree.
    """
    assert isinstance(tree, Tree), "Tree is not a sklearn Tree"

    if not isinstance(feature_vec, list):
        feature_vec = [feature_vec]

    node_id = 0
    leaf_node_id = 0
    # np.NINF / np.PINF were removed in NumPy 2.0; np.inf exists in all versions.
    lower = -np.inf
    upper = np.inf

    max_steps = 2 * tree.node_count
    steps = 0

    # `and` short-circuits, so tree.feature is never indexed with TREE_LEAF.
    while node_id != TREE_LEAF and tree.feature[node_id] != TREE_UNDEFINED:
        feature_idx = tree.feature[node_id]
        threshold = tree.threshold[node_id]

        # The value is cast to float32 before comparing -- presumably to match
        # the precision the tree was fitted with (TODO confirm).
        if np.float32(feature_vec[feature_idx]) <= threshold:
            upper = threshold
            child_id = tree.children_left[node_id]
        else:
            lower = threshold
            child_id = tree.children_right[node_id]

        # Remember the child only when it is a real node. The right branch
        # previously tested `== TREE_LEAF`, so right turns never updated
        # leaf_node_id; both directions now use the same check.
        if child_id != TREE_LEAF and child_id != TREE_UNDEFINED:
            leaf_node_id = child_id
        node_id = child_id

        steps += 1
        if steps > max_steps:
            raise RuntimeError("infinite recursion!")

    return leaf_node_id, lower, upper
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号