def backward(self, inputs, grad_outputs):
    """Receive gradient data through ``self.comm`` and return it as arrays.

    The array module is chosen from ``inputs`` via
    ``cuda.get_array_module``, and the receive plus the array
    construction both run inside the device context obtained from
    ``cuda.get_device_from_array(*inputs)``.

    Args:
        inputs: Arrays used only to select the array module and the
            device context.
        grad_outputs: Not used by this method.

    Returns:
        tuple: One ``xp.array`` per received item when the received
        object is a tuple; otherwise a one-element tuple wrapping the
        single received object.
    """
    xp = cuda.get_array_module(*inputs)
    with cuda.get_device_from_array(*inputs):
        received = self.comm.recv(self.peer_rank, self.peer_tag)
        # A non-tuple payload is a single gradient; the caller still
        # gets a tuple back, so wrap it.
        if not isinstance(received, tuple):
            return (xp.array(received),)
        return tuple(xp.array(item) for item in received)
评论列表
文章目录