models.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:ESL-Model 作者: littlezz 项目源码 文件源码
def train(self):
        X = self.train_x
        y = self.train_y
        # include intercept
        beta = np.zeros((self.p+1, 1))

        iter_times = 0
        while True:
            e_X = np.exp(X @ beta)
            # N x 1
            self.P = e_X / (1 + e_X)
            # W is a vector
            self.W = (self.P * (1 - self.P)).flatten()
            # X.T * W equal (X.T @ diagflat(W)).diagonal()
            beta = beta + self.math.pinv((X.T * self.W) @ X) @ X.T @ (y - self.P)

            iter_times += 1
            if iter_times > self.max_iter:
                break

        self.beta_hat = beta
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号