def grad(self, inp, grads):
    x, = inp
    out = self(*inp)

    # The gradient of an integer-typed output is defined as zero.
    if out.dtype.find('int') != -1:
        return [x.zeros_like(dtype=theano.config.floatX)]

    gz, = grads
    gz = as_tensor_variable(gz)
    axis = self.axis
    if axis is None:
        axis = list(range(x.type.ndim))
    if axis == ():
        return gz,

    # Rebuild the reduced dimensions: every summed-out axis becomes a
    # broadcastable ('x') dimension so gz can be broadcast back to x's shape.
    new_dims = []
    i = 0
    for j, _ in enumerate(x.type.broadcastable):
        if j in axis:
            new_dims.append('x')
        else:
            new_dims.append(i)
            i += 1
    ds_op = DimShuffle(gz.type.broadcastable, new_dims)
    # scalar.second(a, b) returns b, so this broadcasts the dimshuffled gz
    # up to the shape (and broadcast pattern) of x.
    gx = Elemwise(scalar.second)(x, ds_op(gz))
    return [gx]
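
For illustration, here is a minimal usage sketch (assuming a working Theano install, not part of the Op itself) showing the effect of this gradient: summing a matrix over axis 0 and differentiating broadcasts the upstream gradient back over the reduced axis, so the gradient of x.sum(axis=0).sum() with respect to x is a matrix of ones.

import numpy as np
import theano
import theano.tensor as tt

# Hypothetical example: the grad() above is invoked internally by theano.grad
# when differentiating through a sum reduction.
x = tt.matrix('x')
cost = x.sum(axis=0).sum()      # reduce to a scalar so theano.grad applies
gx = theano.grad(cost, x)       # gz is broadcast back over the summed axis
f = theano.function([x], gx)

print(f(np.ones((2, 3), dtype=theano.config.floatX)))
# -> [[1. 1. 1.]
#     [1. 1. 1.]]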