def grad(self, axis_and_tensors, grads):
    """ The gradient wrt a join op is a `Split`, used to partition
    the gradient along the `axis` which was used for joining.
    """
    gz, = grads
    axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
    # The gradient wrt the axis argument (an integer) is undefined.
    rval = [grad_undefined(self, 0, axis)]
    dtypes = [as_tensor_variable(x).type.dtype for x in tensors]
    out_dtype = scal.upcast(*dtypes)
    if 'float' in out_dtype or 'complex' in out_dtype:
        # Assume that this is differentiable.
        split = Split(len(tensors))
        split_gz = split(gz, axis, stack([shape(x)[axis]
                                          for x in tensors]))
        # If there is only one split, it might not be in a list.
        if not isinstance(split_gz, list):
            split_gz = [split_gz]
        # Split.make_node isn't always able to infer the right
        # broadcast pattern. As the gradient needs to keep that
        # information, read it from the inputs if needed.
        split_gz = [patternbroadcast(g, t.broadcastable)
                    for t, g in zip(tensors, split_gz)]
        rval = rval + split_gz
    else:
        # The output has integer type, so the gradient through it
        # is 0.
        rval = rval + [tensor.zeros_like(dtype=config.floatX)
                       for tensor in tensors]
    return rval
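
For intuition, here is a minimal sketch of what this gradient does from the user's side. It assumes a working Theano installation and uses `theano.tensor.join` and `theano.grad` (neither appears in the excerpt above): the upstream gradient of a join is split back along the join axis, so each input receives a gradient of its own shape.

import numpy as np
import theano
import theano.tensor as T

x = T.dvector('x')
y = T.dvector('y')
joined = T.join(0, x, y)            # concatenate along axis 0
cost = (joined ** 2).sum()
gx, gy = theano.grad(cost, [x, y])  # gradients are split back per input

f = theano.function([x, y], [gx, gy])
gx_val, gy_val = f(np.array([1., 2.]), np.array([3., 4., 5.]))
# gx_val has shape (2,) and gy_val has shape (3,): the gradient wrt
# `joined` was partitioned along axis 0 according to each input's
# length, which is exactly what the Split in Join.grad above does.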