def grad(self, inp, grads):
    """Build the symbolic gradient of the max/argmax outputs w.r.t. the inputs.

    The strict-sense mathematical gradient of the maximum function is not
    calculated here, for it is not defined at every point where some
    coordinates are identical. However, since the latter set has null
    Lebesgue measure, the result may be interpreted as a weak gradient.

    With inputs ``(x, axis)`` and output gradients ``(g_max, g_max_idx)``
    the chain rule reads::

        g_x    = g_max * dMax/dx    + g_max_idx * dArgMax/dx
        g_axis = g_max * dMax/daxis + g_max_idx * dArgMax/daxis

    Only the ``g_max`` term contributes to ``g_x`` below; the gradient
    with respect to ``axis`` is reported as undefined.

    @note: This function should work correctly for L{vector}s.

    @param inp: pair ``(x, axis)`` of the op's inputs.
    @param grads: pair ``(g_max, g_max_idx)`` of gradients flowing into
        the max output and the argmax output respectively.
    @return: two-element sequence ``(gradient wrt x, gradient wrt axis)``.
        Both entries are disconnected when both output gradients are
        disconnected; the ``x`` entry is ``x.zeros_like()`` when only the
        max output is disconnected.
    """
    x, axis = inp
    g_max, g_max_idx = grads

    max_is_disconnected = isinstance(g_max.type, DisconnectedType)
    idx_is_disconnected = isinstance(g_max_idx.type, DisconnectedType)

    # If the op is totally disconnected, so are its inputs.
    if max_is_disconnected and idx_is_disconnected:
        return [DisconnectedType()(), DisconnectedType()()]

    axis_grad = grad_undefined(
        self, 1, axis,
        "argmax is not defined for non-integer axes so"
        " argmax(x, axis+eps) is undefined")

    # If the max is disconnected but the argmax is not,
    # the gradient on its inputs is zero.
    if max_is_disconnected:
        return [x.zeros_like(), axis_grad]

    # Evaluate the "reduce over every dimension" test once and keep it in
    # a flag instead of rebinding ``axis`` to None part-way through.
    reduce_all = NoneConst.equals(axis)

    # NOTE(review): ``max`` is presumably the module's symbolic max (it is
    # called with an axis argument), not the Python builtin -- confirm.
    if reduce_all:
        xmax = max(x, list(range(x.ndim)))
    else:
        xmax = max(x, axis)

    # Raise g_max and xmax to the same number of dims as the input:
    # g_max has fewer dimensions than x, so every reduced axis is
    # re-inserted as a broadcastable 'x' entry, while the kept axes are
    # numbered consecutively in the order they appear in the output.
    # NOTE(review): ``axis.data`` is read here, so ``axis`` presumably has
    # to be a constant whenever it is not None -- confirm against callers.
    pattern = []
    out_dim = 0
    for i in range(x.ndim):
        if reduce_all or i in axis.data:
            pattern.append('x')
        else:
            pattern.append(out_dim)
            out_dim += 1

    g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
    xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)

    # Set the grad to the correct position: only the entries of x equal to
    # the maximum receive the incoming gradient.
    g_x = eq(xmax_pad, x) * g_max_pad
    return g_x, axis_grad
# [Web-page navigation residue, not part of the code: "Comment list" / "Table of contents"]