def __init__(self, nr_events, case_id_col, label_col, encoder_kwargs, cls_kwargs, cls_method="rf"):
    """Build the sequence encoder and the classifier used by this model.

    Args:
        nr_events: Forwarded to ``SequenceEncoder`` as ``nr_events``
            (presumably the number of events per case to encode -- TODO confirm).
        case_id_col: Case-identifier column name; stored on the instance and
            forwarded to ``SequenceEncoder``.
        label_col: Label column name; stored on the instance and forwarded
            to ``SequenceEncoder``.
        encoder_kwargs: Extra keyword arguments passed to ``SequenceEncoder``.
        cls_kwargs: Keyword arguments passed to the classifier constructor.
        cls_method: ``"rf"`` (default) selects ``RandomForestClassifier``;
            ``"gbm"`` selects ``GradientBoostingClassifier``.

    Raises:
        ValueError: If ``cls_method`` is neither ``"rf"`` nor ``"gbm"``.
    """
    # Fail fast: an unrecognised method would otherwise leave ``self.cls``
    # undefined, and the error would only appear later, far from its cause.
    if cls_method not in ("gbm", "rf"):
        raise ValueError("Classifier method not known: %r" % (cls_method,))

    self.case_id_col = case_id_col
    self.label_col = label_col
    self.encoder = SequenceEncoder(nr_events=nr_events, case_id_col=case_id_col,
                                   label_col=label_col, **encoder_kwargs)

    if cls_method == "gbm":
        self.cls = GradientBoostingClassifier(**cls_kwargs)
    else:  # cls_method == "rf" (validated above)
        self.cls = RandomForestClassifier(**cls_kwargs)
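# NOTE(review): stray blog-page navigation labels ("Comment list",
# "Table of contents") left over from a copy-paste; not program code.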