utils.py source code

python

Project: Parallel-SGD  Author: angadgill
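import numpy as np
from joblib import Parallel, delayed
from sklearn.linear_model import SGDRegressor


def psgd_1(sgd, n_iter_per_job, n_jobs, X_train, y_train):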
    """
    Parallel SGD implementation using multiprocessing. All workers sync once after running SGD independently for
    n_iter_per_job iterations.

    Parameters
    ----------
    sgd: input SGDRegressor() object
    n_iter_per_job: number of iterations per worker
    n_jobs: number of parallel processes to run
    X_train: train input data
    y_train: train target data

    Returns
    -------
    sgd: the input SGDRegressor() object with updated coef_ and intercept_
    """

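    # Each worker gets a fresh SGDRegressor; the hyperparameters of the input
    # `sgd` are not copied. psgd_method_1 is defined elsewhere in utils.py (not
    # shown here) and is presumably expected to fit and return the estimator.
    # Note: `n_iter` is the parameter name from older scikit-learn releases;
    # newer releases use `max_iter` instead.
    sgds = Parallel(n_jobs=n_jobs)(
        delayed(psgd_method_1)(s, X_train, y_train)
        for s in [SGDRegressor(n_iter=n_iter_per_job) for _ in range(n_jobs)])
    # Single synchronization step: average the parameters across all workers.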
    sgd.coef_ = np.array([x.coef_ for x in sgds]).mean(axis=0)
    sgd.intercept_ = np.array([x.intercept_ for x in sgds]).mean(axis=0)
    return sgd
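
The "train independently, then sync once by averaging" idea in psgd_1 can be illustrated without scikit-learn or joblib. The following is a minimal standard-library sketch, not the project's implementation: run_sgd_worker, average_models, and psgd_sketch are hypothetical names, the worker is a hand-written SGD for linear regression, and threads stand in for the separate processes that joblib would start.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def run_sgd_worker(seed, X, y, n_iter, lr=0.01):
    """Hypothetical worker: plain SGD for linear regression, starting from zeros."""
    rng = random.Random(seed)
    coef = [0.0] * len(X[0])
    intercept = 0.0
    indices = list(range(len(X)))
    for _ in range(n_iter):
        rng.shuffle(indices)  # each worker visits the samples in its own order
        for i in indices:
            pred = intercept + sum(c * x for c, x in zip(coef, X[i]))
            err = pred - y[i]
            # Gradient step for the squared loss 0.5 * err**2
            coef = [c - lr * err * x for c, x in zip(coef, X[i])]
            intercept -= lr * err
    return coef, intercept


def average_models(results):
    """Average coefficients and intercepts across workers (the single sync step)."""
    n = len(results)
    n_features = len(results[0][0])
    coef = [sum(r[0][j] for r in results) / n for j in range(n_features)]
    intercept = sum(r[1] for r in results) / n
    return coef, intercept


def psgd_sketch(X, y, n_iter_per_job, n_jobs):
    """Run n_jobs independent workers on the full data, then average once."""
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        results = list(pool.map(
            lambda seed: run_sgd_worker(seed, X, y, n_iter_per_job),
            range(n_jobs)))
    return average_models(results)


# Toy data generated from y = 2*x0 - 1*x1 + 0.5 with no noise (assumed example)
data_rng = random.Random(0)
X_demo = [[data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)] for _ in range(200)]
y_demo = [2.0 * a - 1.0 * b + 0.5 for a, b in X_demo]
coef_demo, intercept_demo = psgd_sketch(X_demo, y_demo, n_iter_per_job=50, n_jobs=4)
```

As in psgd_1, every worker sees the whole training set and the only communication is the final averaging, so the sketch shows the synchronization pattern rather than any speedup.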