def __init__(self, nr_events, case_id_col, encoder_kwargs, cls_kwargs, cls_method="rf"):
    """Set up the sequence encoder and the underlying regressor.

    Args:
        nr_events: Passed through to ``SequenceEncoder`` and stored on
            ``self.nr_events``.
        case_id_col: Name of the case-id column; passed to
            ``SequenceEncoder`` and stored on ``self.case_id_col``.
        encoder_kwargs: Extra keyword arguments forwarded to
            ``SequenceEncoder``.
        cls_kwargs: Keyword arguments forwarded to the regressor constructor.
        cls_method: Which regressor to build: ``"gbm"`` for
            ``GradientBoostingRegressor`` or ``"rf"`` (default) for
            ``RandomForestRegressor``.

    Raises:
        ValueError: If ``cls_method`` is not ``"gbm"`` or ``"rf"``.
    """
    # Validate up front so an unknown method fails loudly instead of leaving
    # a half-built object with no ``self.cls`` (the original only printed a
    # message and carried on).
    if cls_method not in ("gbm", "rf"):
        raise ValueError(f"Classifier method not known: {cls_method!r}")

    self.case_id_col = case_id_col
    self.nr_events = nr_events
    self.encoder = SequenceEncoder(nr_events=nr_events, case_id_col=case_id_col, **encoder_kwargs)

    if cls_method == "gbm":
        self.cls = GradientBoostingRegressor(**cls_kwargs)
    else:  # cls_method == "rf", guaranteed by the check above
        self.cls = RandomForestRegressor(**cls_kwargs)
评论列表
文章目录