import time

def all_reduce_params(sent_shared_params, rec_buffers, average_cnt=1):
    from mpi4py import MPI
    mpi_communicator = MPI.COMM_WORLD
    commu_time = 0.0
    gpu2cpu_cp_time = 0.0
    for (sent_model, rec_model) in zip(sent_shared_params, rec_buffers):
        # Copy the parameter from the GPU-backed shared variable to host memory.
        cp_start = time.time()
        model_val = sent_model.get_value()
        gpu2cpu_cp_time += time.time() - cp_start
        # Sum this parameter across all workers into the receive buffer.
        commu_start = time.time()
        mpi_communicator.Allreduce([model_val, MPI.FLOAT],
                                   [rec_model, MPI.FLOAT], op=MPI.SUM)
        commu_time += time.time() - commu_start
        if average_cnt != 1:  # skip the division when possible, since it is costly
            rec_model /= average_cnt  # divide in place to avoid an extra copy
        # Copy the averaged result back into the shared variable (host to GPU).
        cp_start = time.time()
        sent_model.set_value(rec_model)
        gpu2cpu_cp_time += time.time() - cp_start
    return commu_time, gpu2cpu_cp_time
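
A minimal usage sketch follows. It assumes the parameters are Theano shared variables (consistent with the get_value/set_value calls above) and that the receive buffers are preallocated NumPy arrays of matching shape and dtype; the parameter names and shapes here are hypothetical, and the script would be launched under MPI (e.g. mpirun -n 4 python train.py).

import numpy as np
import theano

# Hypothetical model parameters held as Theano shared variables.
W = theano.shared(np.random.randn(784, 256).astype('float32'), name='W')
b = theano.shared(np.zeros(256, dtype='float32'), name='b')
params = [W, b]

# One preallocated receive buffer per parameter, matching shape and dtype
# (float32, to agree with MPI.FLOAT in the Allreduce call).
rec_buffers = [np.empty_like(p.get_value()) for p in params]

# Sum each parameter across 4 workers, then divide by 4 to average.
commu_time, cp_time = all_reduce_params(params, rec_buffers, average_cnt=4)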