utils.py source code

python

Project: Parallel-SGD   Author: angadgill
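import numpy as np
from joblib import Parallel, delayed  # older scikit-learn also bundled this as sklearn.externals.joblib
from sklearn.linear_model import SGDRegressor


def psgd_4(sgd, n_iter_per_job, n_jobs, X_train, y_train, coef, intercept):
    """
    Parallel SGD implementation using multithreading. All workers read coef and intercept from shared memory,
    update them, and then overwrite them.

    Parameters
    ----------
    sgd: input SGDRegressor() object
    n_iter_per_job: number of iterations per worker
    n_jobs: number of parallel workers (threads) to run
    X_train: training input data
    y_train: training target data
    coef: randomly initialized coefficients stored in shared memory
    intercept: randomly initialized intercept stored in shared memory

    Returns
    -------
    sgd: the input SGDRegressor() object, with coef_ and intercept_ set to the average over all workers
    """
    # One warm-started regressor per worker; with n_iter=1 each fit() call makes a single pass.
    # Note: `n_iter` is the pre-0.21 scikit-learn argument (deprecated in 0.19 in favour of `max_iter`).
    sgds = [SGDRegressor(warm_start=True, n_iter=1)
            for _ in range(n_jobs)]

    # psgd_method_2 is a helper defined elsewhere in utils.py (not shown here); each call receives
    # one worker's regressor together with the shared coef/intercept and the training data.
    sgds = Parallel(n_jobs=n_jobs, backend="threading")(
        delayed(psgd_method_2)(s, n_iter_per_job, coef, intercept, X_train, y_train)
        for s in sgds)

    # Average the per-worker models into the estimator that is returned.
    sgd.coef_ = np.array([x.coef_ for x in sgds]).mean(axis=0)
    sgd.intercept_ = np.array([x.intercept_ for x in sgds]).mean(axis=0)
    return sgd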
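Because the body of psgd_method_2 is not shown, the sketch below does not reproduce it. It is a minimal, hypothetical illustration of the same overall pattern, assuming plain squared-loss SGD written in NumPy (in place of SGDRegressor) and Python's threading.Thread (in place of joblib's threading backend). The names hogwild_worker and parallel_sgd are made up for this example. Each thread reads and overwrites the shared coef and intercept arrays in place without a lock, returns a snapshot when it finishes, and the snapshots are then averaged in the same way psgd_4 averages its workers.

```python
import threading

import numpy as np


def hogwild_worker(coef, intercept, X, y, n_iter, lr, seed):
    """Hypothetical worker: squared-loss SGD that reads and overwrites the
    shared coef/intercept arrays in place, without any locking."""
    rng = np.random.default_rng(seed)
    for _ in range(n_iter):
        for i in rng.permutation(len(y)):
            # Read the current shared parameters (other threads may be mid-update).
            err = X[i] @ coef + intercept[0] - y[i]
            # Gradient step written straight back into the shared arrays.
            coef -= lr * err * X[i]
            intercept[0] -= lr * err
    # Private snapshot of the shared state at the moment this worker finished.
    return coef.copy(), intercept.copy()


def parallel_sgd(X, y, n_jobs=4, n_iter_per_job=5, lr=0.01):
    """Hypothetical driver mirroring psgd_4: run n_jobs threads on shared
    parameters, then average the per-worker snapshots."""
    coef = np.zeros(X.shape[1])  # shared coefficients
    intercept = np.zeros(1)      # shared intercept (1-element array so it is mutable)
    results = [None] * n_jobs

    def run(k):
        results[k] = hogwild_worker(coef, intercept, X, y, n_iter_per_job, lr, seed=k)

    threads = [threading.Thread(target=run, args=(k,)) for k in range(n_jobs)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Same averaging step as the last two assignments in psgd_4.
    final_coef = np.array([c for c, _ in results]).mean(axis=0)
    final_intercept = np.array([b for _, b in results]).mean(axis=0)
    return final_coef, final_intercept


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 2))
    y = X @ np.array([3.0, -2.0]) + 1.0  # noiseless target: coef=[3, -2], intercept=1
    coef, intercept = parallel_sgd(X, y)
    print(coef, intercept)
```

Since the threads update the same arrays without a lock, an update can be lost when two threads race on the same entry; lock-free (Hogwild-style) SGD accepts this in exchange for avoiding synchronization. Under CPython's GIL the pure-Python parts of each thread's loop do not run simultaneously, so this sketch demonstrates the shared-memory pattern rather than a speedup.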