def backward_gpu(self, inputs, grad_output):
    """Compute the gradient w.r.t. the weight array on GPU.

    Args:
        inputs: one-element tuple holding the weight array ``w``.
        grad_output: one-element tuple holding the gradient flowing
            back from the rotated output.

    Returns:
        One-element tuple containing the gradient w.r.t. ``w``.
    """
    (w,) = inputs
    (g_rotated,) = grad_output
    xp = cuda.get_array_module(w)
    # The kernel adds into its output buffer instead of assigning,
    # so start from an all-zero array matching w's shape and dtype.
    g_w = xp.zeros_like(w)
    grad_index_group_func_kernel(
        grad_output=g_rotated,
        T=self.T,
        U=self.U,
        V=self.V,
        grad_input=g_w,
    )
    return g_w,