def grad(self, inputs, gout):
    (x, y) = inputs
    (gz,) = gout
    if x.type in complex_types:
        raise NotImplementedError()
    # If the output of this op is discrete, then it
    # is locally flat everywhere, so the gradient
    # through it is 0.
    # This is different from it not being connected
    # to the output; x/y is still a function of x
    # and y; it's just a step function.
    if all(a.dtype in discrete_types for a in (x, y)):
        return [x.zeros_like(), y.zeros_like()]
    first_part = gz / y
    if y.type in complex_types:
        raise NotImplementedError()
    second_part = -(gz * x) / (y * y)
    return first_part, second_part
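
The two returned expressions are just the quotient rule, d(x/y)/dx = 1/y and d(x/y)/dy = -x/y**2, each scaled by the upstream gradient gz. A minimal standalone sketch (plain NumPy with made-up sample values, not part of the Op itself) that checks those formulas against finite differences:

import numpy as np

# Standalone check of the symbolic expressions above (assumed sample values,
# not library code): first_part = gz / y, second_part = -(gz * x) / (y * y).
x, y, gz = 3.0, 2.0, 1.0
eps = 1e-6

analytic_dx = gz / y                # matches first_part
analytic_dy = -(gz * x) / (y * y)   # matches second_part

# Finite-difference approximations of d(x/y)/dx and d(x/y)/dy.
numeric_dx = ((x + eps) / y - x / y) / eps
numeric_dy = (x / (y + eps) - x / y) / eps

assert np.isclose(analytic_dx, numeric_dx, atol=1e-4)
assert np.isclose(analytic_dy, numeric_dy, atol=1e-4)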