tools.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:caltech-machine-learning 作者: zhiyanfoo 项目源码 文件源码
def svm(x, y):
    """
    classification SVM

    Minimize
    1/2 * w^T w
    subject to
    y_n (w^T x_n + b) >= 1
    """
    weights_total = len(x[0])
    I_n = np.identity(weights_total-1)
    P_int =  np.vstack(([np.zeros(weights_total-1)], I_n))
    zeros = np.array([np.zeros(weights_total)]).T
    P = np.hstack((zeros, P_int))
    q = np.zeros(weights_total)
    G = -1 * vec_to_dia(y).dot(x)
    h = -1 * np.ones(len(y))
    matrix_arg = [ matrix(x) for x in [P,q,G,h] ]
    sol = solvers.qp(*matrix_arg)
    return np.array(sol['x']).flatten()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号