def backward(self, x, gy):
    """Build a gradient shaped like ``x[0]`` by broadcasting ``gy[0]``.

    Presumably the backward pass of a reduction taken over ``self.axis``
    (confirm against the enclosing class, which is not visible here).

    Args:
        x: Sequence of input arrays; ``x[0]`` supplies the shape and dtype
            of the result, and all of ``x`` selects the array module.
        gy: Sequence whose first element is the incoming gradient.

    Returns:
        A one-element tuple holding the freshly allocated array, filled
        with ``gy[0]`` broadcast to the shape of ``x[0]``.
    """
    xp = cuda.get_array_module(*x)
    gx = xp.empty_like(x[0])

    # No axis given: the incoming gradient is broadcast over the whole
    # output as-is.
    if self.axis is None:
        gx[:] = gy[0]
        return (gx,)

    grad = gy[0]
    ndim = len(gx.shape)
    # Map negative axes onto their non-negative equivalents.
    positions = [ndim + ax if ax < 0 else ax for ax in self.axis]
    # Insert the length-1 dimensions in ascending order so that every
    # index refers to a position in the progressively expanded array.
    positions.sort()
    for pos in positions:
        grad = xp.expand_dims(grad, axis=pos)

    # Broadcasting assignment stretches the length-1 axes to full size.
    gx[:] = grad
    return (gx,)
评论列表
文章目录