def test_cat_feature(self, query, feature, categories=None):
    """
    Find splits on a categorical feature.

    Each candidate value is substituted into ``query`` for the feature
    (via ``make_query``) and the resulting query is passed to
    ``self.predict``. Values are then grouped by the identifier that
    ``self.predict`` returned for them.

    :param query: base query handed to ``make_query`` for every value.
    :param feature: object exposing ``name`` and ``vals``.
    :param categories: values to try; when falsy (``None`` or empty),
        ``feature.vals`` is used instead.
    :return: dict mapping each identifier returned by ``self.predict``
        (presumably a leaf ID -- confirm against ``predict``) to the list
        of values that produced it, in the order they were tried.
    """
    # Fall back to the feature's own values when none were supplied.
    candidates = categories or feature.vals

    # Identifier returned by predict -> every value that led to it.
    values_by_id = {}
    for candidate in candidates:
        # Probe one value at a time.
        probe = make_query(query, feature.name, candidate)
        predicted_id = self.predict(probe)
        logging.log(DEBUG, '\t val {} got {}'.format(candidate, predicted_id))
        values_by_id.setdefault(predicted_id, []).append(candidate)
    return values_by_id
# (web-page residue, not code: "comment list" / "table of contents")