def _do_bprop(self, env, idx):
    """Return the gradient of the output with respect to input ``idx``.

    Only input 0 (the data array ``x``) receives a gradient. Input 1 (``i``,
    used below as an integer index array) gets ``0``.

    For input 0 the result is a zero array with ``x``'s dtype and shape. For
    every position of ``x`` outside ``self._axis``, the matching entry of the
    output gradient ``g`` is written at index ``i[...]`` along ``self._axis``.

    ``i`` and ``g`` must each hold exactly one entry per position of ``x``
    outside ``self._axis``. Their flattened sizes must equal the product of
    ``x.shape`` with that axis removed, otherwise the final reshape fails.

    Args:
        env: Unused by this method; kept for interface compatibility.
        idx: Position of the input whose gradient is requested.

    Returns:
        A numpy array shaped like ``x`` when ``idx == 0``; ``0`` otherwise.
    """
    if idx != 0:
        return 0

    x = self.inputs[0].get_value()
    i = self.inputs[1].get_value()
    g = self.outputs[0].get_grad()

    xshape = tuple(x.shape)
    ndim = len(xshape)
    # Normalise a negative axis. Without this, xshape[axis + 1:] below is
    # wrong (axis=-1 would give xshape[0:], i.e. the whole shape).
    axis = self._axis + ndim if self._axis < 0 else self._axis
    alen = xshape[axis]

    i_hat = i.reshape(-1)
    g_hat = g.reshape(-1)
    # One row per position outside `axis`. Use the flattened count, not
    # i.shape[0], which is wrong whenever i has more than one dimension.
    r_hat = np.zeros((i_hat.shape[0], alen), dtype=x.dtype)
    # Each row receives exactly one gradient value, in the column chosen by i.
    r_hat[np.arange(r_hat.shape[0]), i_hat] = g_hat

    r = r_hat.reshape(xshape[:axis] + xshape[axis + 1:] + (alen,))
    # The scattered axis sits last; move it back to its original position.
    return np.moveaxis(r, -1, axis)
评论列表
文章目录