def grad(self, inp, grads):
    """Return the gradients of the cost with respect to each input.

    Parameters
    ----------
    inp : sequence of 3 symbolic variables
        ``(x, y, inverse)``, the inputs of the node being differentiated.
    grads : sequence of 1 symbolic variable
        ``(gz,)``, the gradient of the cost with respect to the output.

    Returns
    -------
    list
        ``[gx, gy, ginverse]``. ``gx`` is the gradient with respect to
        ``x``; it is ``x.zeros_like()`` when ``x`` has an integer dtype.
        The entries for ``y`` and ``inverse`` are marked undefined via
        ``grad_undefined``.
    """
    x, y, inverse = inp
    gz, = grads

    if 'int' in x.type.dtype:
        # If x is an integer type, then so is the output. This means
        # f(x + eps) = f(x), so the gradient with respect to x is zero.
        # Checked first so the permute/sum graph below is not built
        # only to be thrown away.
        gx = x.zeros_like()
    else:
        # First, compute the gradient wrt the broadcasted x.
        # If 'inverse' is False (0), apply the inverse of y on gz;
        # else, apply y on gz.
        gx = permute_row_elements(gz, y, eq(inverse, 0))

        # If x has been broadcasted along some axes, sum the gradient
        # over these axes, but keep the dimension (as broadcastable).
        ndim = gz.type.ndim
        broadcasted_dims = [dim for dim in range(ndim)
                            if x.type.broadcastable[dim] and
                            not gz.type.broadcastable[dim]]
        gx = Sum(axis=broadcasted_dims)(gx)

        # Sum(...) removed the dimensions in broadcasted_dims, so put
        # them back as broadcastable ('x') dimensions. The axes that
        # survived the sum are numbered 0, 1, 2, ... in order, so each
        # non-broadcasted position takes the next index from kept_axes
        # (next() is only evaluated on the else branch).
        kept_axes = iter(range(ndim))
        newdims = ['x' if dim in broadcasted_dims else next(kept_axes)
                   for dim in range(ndim)]
        gx = DimShuffle(gx.type.broadcastable, newdims)(gx)
        assert gx.type.broadcastable == x.type.broadcastable

    # The elements of y and of inverse both affect the output, so they
    # are connected to the output. The transformation isn't defined if
    # their values are non-integer, so the gradient with respect to them
    # is undefined. The second argument of grad_undefined is the input's
    # position in `inp`: y is input 1 and inverse is input 2.
    return [gx,
            grad_undefined(self, 1, y),
            grad_undefined(self, 2, inverse)]
评论列表
文章目录