boosting.py source file


Project: Adaboost  Author: shzygmyx
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def __init__(self, X, y, estimator=DecisionTreeClassifier, itern=20, mode="sign"):
    self.X = X
    self.estimator = estimator  # base learner class fitted in each boosting round
    self.mode = mode
    self.itern = itern  # number of boosting iterations
    self.estimators = []  # estimators produced by the boosting algorithm
    self.alphas = np.array([])  # weight of each boosted estimator
    self.m = self.X.shape[0]  # number of samples
    self.w = np.full(self.m, 1.0 / self.m)  # uniform initial sample weights
    labels = np.asarray(y)
    self.cls_list = []  # distinct class labels, in order of first appearance
    self.cls0 = labels[0]  # samples of this class are encoded as +1
    for label in labels:
        if label not in self.cls_list:
            self.cls_list.append(label)
    if len(self.cls_list) != 2:
        raise ValueError(
            "This Adaboost only supports two-class problems; for multiclass "
            "problems, please use AdaboostMH.")
    # Encode labels as +1 (class cls0) and -1 (the other class) in a fresh
    # integer array; writing 1/-1 into a copy of a string-label array would
    # store the strings "1"/"-1" instead.
    self.y = np.where(labels == self.cls0, 1, -1)
    self.train()
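The constructor ends by calling `self.train()`, whose body is not part of this snippet. Below is a minimal, hypothetical stand-alone sketch of discrete AdaBoost that starts from the same state the constructor builds: labels re-encoded as +1/-1 and uniform weights 1/m. Each round picks the stump h_t with the lowest weighted error e_t, gives it the vote alpha_t = 0.5 * ln((1 - e_t) / e_t), multiplies every sample weight by exp(-alpha_t * y_i * h_t(x_i)) and renormalizes; the final prediction is the sign of the alpha-weighted vote, which is one plausible reading of the `mode="sign"` default (an assumption). This is not the project's `train` method: it replaces scikit-learn's `DecisionTreeClassifier` with hand-rolled one-feature decision stumps so it needs only the standard library, and the names `encode_labels`, `fit_stump`, `stump_predict`, `adaboost_train` and `adaboost_predict` are illustrative.

```python
import math


def encode_labels(labels):
    """Map the first label seen to +1 and the other class to -1 (as the constructor does)."""
    classes = []
    for lab in labels:
        if lab not in classes:
            classes.append(lab)
    if len(classes) != 2:
        raise ValueError("only two-class problems are supported")
    return [1 if lab == classes[0] else -1 for lab in labels]


def fit_stump(X, y, w):
    """Try every (feature, threshold, polarity) stump; return the one with the
    lowest weighted error as a tuple (error, feature, threshold, polarity)."""
    best = None
    for j in range(len(X[0])):
        for t in sorted({row[j] for row in X}):
            for polarity in (1, -1):
                # The stump predicts `polarity` when x_j <= t, else -polarity.
                err = sum(wi for row, yi, wi in zip(X, y, w)
                          if (polarity if row[j] <= t else -polarity) != yi)
                if best is None or err < best[0]:
                    best = (err, j, t, polarity)
    return best


def stump_predict(stump, row):
    _, j, t, polarity = stump
    return polarity if row[j] <= t else -polarity


def adaboost_train(X, y, itern=20):
    """Discrete AdaBoost with decision stumps; y must already be encoded as +1/-1."""
    m = len(X)
    w = [1.0 / m] * m  # uniform initial weights, like self.w in the constructor
    stumps, alphas = [], []
    for _ in range(itern):
        stump = fit_stump(X, y, w)
        err = stump[0]  # weighted error (weights are normalized to sum to 1)
        if err >= 0.5:  # no better than chance: stop boosting
            break
        alpha = 0.5 * math.log((1 - err) / max(err, 1e-10))  # guard against err == 0
        stumps.append(stump)
        alphas.append(alpha)
        if err == 0:  # zero training error: stop
            break
        # Up-weight misclassified samples, down-weight correct ones, renormalize.
        w = [wi * math.exp(-alpha * yi * stump_predict(stump, row))
             for row, yi, wi in zip(X, y, w)]
        total = sum(w)
        w = [wi / total for wi in w]
    return stumps, alphas


def adaboost_predict(stumps, alphas, row):
    """Sign of the alpha-weighted vote of all stumps (ties go to +1)."""
    score = sum(a * stump_predict(s, row) for s, a in zip(stumps, alphas))
    return 1 if score >= 0 else -1
```

For example, `adaboost_train([[1], [2], [3], [4]], encode_labels(["a", "a", "b", "b"]))` stops after a single stump, because the threshold `x <= 2` already separates the two classes. In the project itself, each round presumably fits `estimator` (by default scikit-learn's `DecisionTreeClassifier`, whose `fit` accepts a `sample_weight` argument) rather than a hand-rolled stump; that is an inference from the constructor's arguments, not something the snippet shows.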