def grad(self, inp, grads):
    s, = inp
    dt, = grads
    if s.type.dtype in float_dtypes:
        assert dt.type.dtype in float_dtypes
        # Float case: the incoming tensor gradient is converted back
        # to a scalar so its type matches the scalar input's type.
        return [scalar_from_tensor(dt)]

    # If the input dtype is an integer, then so is the output dtype,
    # and the "zero" gradient could be represented in that int dtype.
    # Currently, however, theano.grad insists that the dtype of the
    # returned gradient be a float dtype, so we use floatX.
    if s.type.dtype in discrete_dtypes:
        return [s.zeros_like().astype(theano.config.floatX)]

    raise NotImplementedError("grad not implemented for complex dtypes")
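
For context, this reads like the gradient of Theano's TensorFromScalar op, which wraps a theano.scalar variable in a 0-d tensor. The sketch below shows how a gradient flows through it; it assumes the op instance is exposed as theano.tensor.basic.tensor_from_scalar, as in Theano's source tree, and is a minimal illustration rather than the canonical usage:

import theano
import theano.scalar

from theano.tensor.basic import tensor_from_scalar

# A scalar-typed variable (theano.scalar, not a 0-d tensor).
x = theano.scalar.float64('x')

# Wrap it in a 0-d tensor. The grad() above routes the incoming
# tensor gradient back through scalar_from_tensor to match x's type.
y = tensor_from_scalar(x)

g = theano.grad(y, x)           # gradient of y with respect to x
f = theano.function([x], g)
print(f(2.0))                   # expected: 1.0, since dy/dx = 1

Had x been integer-typed, the second branch would apply instead: the gradient would come back as zeros cast to theano.config.floatX, because theano.grad requires a float-dtype gradient even when a zero gradient is representable in the integer dtype.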