utils.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:DL4NMT_Theano 作者: fyabc 项目源码 文件源码
def all_reduce_params(sent_shared_params, rec_buffers, average_cnt=1):
    """Sum every parameter across all MPI ranks and store the result back.

    For each ``(param, buffer)`` pair, the parameter's current value is read
    with ``get_value()``, combined over ``MPI.COMM_WORLD`` with an
    ``Allreduce`` using ``MPI.SUM``, optionally divided by ``average_cnt``,
    and written back with ``set_value()``.

    Args:
        sent_shared_params: Iterable of objects exposing ``get_value()`` and
            ``set_value(value)`` (presumably Theano shared variables --
            confirm against the caller).
        rec_buffers: Iterable of receive buffers, one per parameter, handed
            to ``Allreduce`` as ``[buffer, MPI.FLOAT]``.
            NOTE(review): both send and receive sides are declared as
            ``MPI.FLOAT``, so the data is assumed to be single precision --
            confirm.
        average_cnt: Divisor applied to each reduced result. The division is
            skipped entirely when it equals 1.

    Returns:
        A ``(commu_time, gpu2cpu_cp_time)`` tuple of wall-clock seconds
        measured with ``time.time()``: the time spent inside ``Allreduce``
        and the time spent inside ``get_value()`` / ``set_value()``. The
        averaging division is counted in neither figure.

    Note:
        The two inputs are paired with ``zip``, so iteration stops at the
        shorter one. The averaged result is bound to a local name only; the
        entry in ``rec_buffers`` is not reassigned.
    """
    # Imported lazily so mpi4py is only needed when this function is called.
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    comm_seconds = 0.0
    copy_seconds = 0.0

    for shared_param, recv_buf in zip(sent_shared_params, rec_buffers):
        # Fetch the parameter's current value.
        fetch_started = time.time()
        host_value = shared_param.get_value()
        copy_seconds = copy_seconds + (time.time() - fetch_started)

        # Element-wise sum over every rank; the result lands in recv_buf.
        reduce_started = time.time()
        comm.Allreduce(
            [host_value, MPI.FLOAT],
            [recv_buf, MPI.FLOAT],
            op=MPI.SUM,
        )
        comm_seconds = comm_seconds + (time.time() - reduce_started)

        # Dividing is costly, so it is skipped when the divisor is 1.
        reduced = recv_buf / average_cnt if average_cnt != 1 else recv_buf

        # Store the (possibly averaged) result back into the parameter.
        store_started = time.time()
        shared_param.set_value(reduced)
        copy_seconds = copy_seconds + (time.time() - store_started)

    return comm_seconds, copy_seconds
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号