def fit(self, X, y=None):
    """Learn discretization cut points from per-feature decision-tree splits.

    For every feature, a ``sklearn.tree.DecisionTreeClassifier`` (entropy
    criterion, ``max_depth=self.max_depth``, plus ``**self.kwds``) is fitted
    on that single feature against ``y``. Its split thresholds, as extracted
    by ``getTreeSplits``, become the feature's cut points. When
    ``getTreeSplits`` returns None (no split found), the feature's median is
    used as the single cut point.

    Parameters
    ----------
    X : array-like
        Feature data. A 1-D input (e.g. a pandas Series or 1-D ndarray) is
        treated as one feature. A 2-D input that has a ``columns`` attribute
        (e.g. a pandas DataFrame) has its columns selected by label; any
        other 2-D input is converted with ``np.asarray`` and its columns are
        selected by position.
    y : array-like
        Target labels. Required despite the ``None`` default.

    Returns
    -------
    self
        The fitted instance. ``self.cuts`` is set to a numpy array of cut
        points for 1-D ``X``, or to a dict mapping each feature name to its
        array of cut points for 2-D ``X``. Feature names are taken from
        ``self.feature_names`` if set, else ``X.columns``, else
        ``range(number_of_columns)``.

    Raises
    ------
    ValueError
        If ``y`` is None.
    """
    if y is None:
        raise ValueError('y must not be None')
    dt = sklearn.tree.DecisionTreeClassifier(
        criterion='entropy', max_depth=self.max_depth, **self.kwds)

    def _feature_cuts(values):
        # Fit the shared tree on one feature and return its split thresholds,
        # falling back to the median when the tree produces no split.
        # np.asarray makes the reshape work for pandas Series as well
        # (Series.reshape no longer exists in pandas >= 1.0).
        column = np.asarray(values).reshape(-1, 1)
        dt.fit(column, y)
        found = getTreeSplits(dt)
        if found is None:
            found = np.array([np.median(column)])
        return found

    if np.ndim(X) == 1:
        # Single feature: cuts is a plain array of thresholds.
        cuts = _feature_cuts(X)
    else:
        # Label-based column access for DataFrame-like inputs, positional
        # access for everything else.
        labelled = hasattr(X, 'columns')
        data = X if labelled else np.asarray(X)

        if self.feature_names is not None:
            feature_names = self.feature_names
        elif labelled:
            feature_names = list(X.columns)
        else:
            feature_names = list(range(data.shape[1]))

        cuts = {}
        for feature in feature_names:
            values = data[feature] if labelled else data[:, feature]
            cuts[feature] = _feature_cuts(values)

    # Deep copy so self.cuts never shares storage with whatever
    # getTreeSplits returned.
    self.cuts = copy.deepcopy(cuts)
    return self
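# 评论列表 (comment list) -- navigation text left over from the source web page
# 文章目录 (table of contents) -- navigation text left over from the source web page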