logistic_regression.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:prml 作者: Yevgnen 项目源码 文件源码
def logistic_regression(x, t, w, eps=1e-2, max_iter=int(1e3)):
    N = x.shape[1]
    Phi = np.vstack([np.ones(N), phi(x)]).T

    for k in range(max_iter):
        y = expit(Phi.dot(w))
        R = np.diag(np.ones(N) * (y * (1 - y)))
        H = Phi.T.dot(R).dot(Phi)
        g = Phi.T.dot(y - t)

        w_new = w - linalg.solve(H, g)

        diff = linalg.norm(w_new - w) / linalg.norm(w)
        if (diff < eps):
            break

        w = w_new
        print('{0:5d} {1:10.6f}'.format(k, diff))

    return w
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号