def backward(self, inputs, grad_outputs):
    """Backward pass of the all-to-all collective.

    The gradient of an all-to-all is itself an all-to-all: the gradient
    chunks received for each destination process are exchanged back, so
    each process recovers the gradients for the chunks it originally sent.

    Args:
        inputs: Forward-pass input arrays; used only to pick the array
            backend (NumPy/CuPy) and the device context.
        grad_outputs: One gradient array per process in the communicator.

    Returns:
        Tuple of gradient arrays, one per input chunk, converted to the
        current backend on the device of ``inputs``.
    """
    # One gradient chunk must arrive for every rank in the communicator.
    assert self.comm.size == len(grad_outputs)
    xp = cuda.get_array_module(*inputs)
    with cuda.get_device_from_array(*inputs):
        # tuple(grad_outputs) directly; the original identity
        # comprehension copied the sequence element-by-element for no gain.
        gys = tuple(grad_outputs)
        gx = self.comm.alltoall(gys)
        # Materialize each received chunk as a backend array on this
        # device; a generator avoids the intermediate list the original
        # built before its final tuple() conversion.
        return tuple(xp.array(chunk) for chunk in gx)