utils.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def parallel_sgd(pool, sgd, n_iter, n_jobs, n_sync, data):
    """
    High level parallelization of SGDRegressor.

    Trains ``n_jobs`` SGDRegressor models, one per element of ``data``,
    and averages their coefficients and intercepts ``n_sync`` times. After
    each averaging step every worker model is reset to the averaged
    parameters before the next round of training.

    Parameters
    ----------
    pool: multiprocessing pool (any object with a ``map(func, iterable)``
        method) to use for this parallelization
    sgd: SGDRegressor instance whose coef and intercept need to be updated.
        Only its ``eta0`` is forwarded to the worker models.
    n_iter: total number of iterations per worker. Each sync round runs
        ``n_iter // n_sync`` iterations; any remainder
        (``n_iter % n_sync``) is not run.
    n_jobs: number of parallel workers
    n_sync: number of synchronization steps (must be >= 1). Syncs are
        spread evenly throughout the iterations
    data: sequence of (X, y) data for the workers. It must have exactly
        n_jobs elements

    Returns
    -------
    sgd: the same SGDRegressor instance, with coef_ and intercept_ set to
        the averaged worker parameters from the final sync

    Raises
    ------
    ValueError
        If n_sync < 1, if n_iter < n_sync (zero iterations per sync),
        or if len(data) != n_jobs.
    """
    if n_sync < 1:
        raise ValueError("n_sync must be >= 1, got %r" % (n_sync,))
    # Floor division: plain ``/`` yields a float on Python 3, which is not
    # a valid iteration count for SGDRegressor.
    n_iter_sync = n_iter // n_sync  # Iterations per model between syncs
    if n_iter_sync < 1:
        raise ValueError(
            "n_iter (%r) must be >= n_sync (%r)" % (n_iter, n_sync))
    # Materialize once: ``data`` is zipped again on every sync, so a
    # one-shot iterator would be exhausted after the first round.
    data = list(data)
    if len(data) != n_jobs:
        raise ValueError(
            "data must have n_jobs=%r elements, got %r"
            % (n_jobs, len(data)))

    eta = sgd.eta0
    # NOTE(review): only eta0 is taken from ``sgd``; other hyperparameters
    # (alpha, penalty, loss, ...) use SGDRegressor defaults -- confirm this
    # is intended. ``n_iter`` is the older scikit-learn keyword for the
    # epoch count (later replaced by ``max_iter``).
    sgds = [SGDRegressor(warm_start=True, n_iter=n_iter_sync, eta0=eta)
            for _ in range(n_jobs)]

    for _ in range(n_sync):
        # psgd_method presumably fits each model on its (X, y) shard and
        # returns the fitted model -- TODO confirm against its definition.
        sgds = pool.map(psgd_method, list(zip(sgds, data)))
        coef = np.array([s.coef_ for s in sgds]).mean(axis=0)
        intercept = np.array([s.intercept_ for s in sgds]).mean(axis=0)
        # Give every worker its own copy of the averaged parameters. With
        # warm_start the next fit starts from coef_, and one shared array
        # could let an in-place update in one model leak into the others
        # (e.g. with a thread pool, or when several models are pickled in
        # the same pool.map chunk).
        for s in sgds:
            s.coef_ = coef.copy()
            s.intercept_ = intercept.copy()

    sgd.coef_ = coef
    sgd.intercept_ = intercept

    return sgd
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号