def mpi_average_gradients(self, arr, num_replicas=None):
    """Average ``arr`` element-wise across all MPI ranks in ``self.comm``.

    Every rank in the communicator must call this method, because it
    performs a collective ``Allreduce``. Ranks whose ``self.task_index`` is
    ``>= num_replicas`` still take part in the reduction but contribute
    zeros, so only ranks with ``task_index < num_replicas`` influence the
    result.

    Args:
        arr: Local array to average. Presumably a floating-point numpy
            array (the in-place division below fails for integer dtypes)
            -- confirm against callers. This method does not modify it.
        num_replicas: Number of contributing ranks; the reduced sum is
            divided by this value. Defaults to ``self.num_workers``.

    Returns:
        A new array with the same shape and dtype as ``arr`` holding the
        summed contributions divided by ``num_replicas``.

    Raises:
        ValueError: If ``num_replicas`` is not positive.
    """
    if num_replicas is None:
        num_replicas = self.num_workers
    if num_replicas <= 0:
        raise ValueError(
            "num_replicas must be positive, got {!r}".format(num_replicas))

    if self.task_index >= num_replicas:
        # Send a fresh zero buffer instead of scaling ``arr`` by 0.0.
        # Scaling in place both clobbered the caller's array and failed to
        # zero non-finite entries (0.0 * inf and 0.0 * nan are nan), which
        # would then poison the averaged result on every rank.
        send_buf = np.zeros_like(arr)
    else:
        send_buf = arr

    arr_global = np.empty_like(arr)
    # Collective call: all ranks must reach this line, including the
    # excluded ones that only contribute zeros.
    self.comm.Allreduce(send_buf, arr_global, op=MPI.SUM)
    arr_global /= num_replicas
    return arr_global
评论列表
文章目录