def grad(self, inputs, gout):
    """Gradient of the maximum op with respect to its two inputs.

    The incoming gradient ``gz`` is routed to the input whose value was
    selected as the output. On ties (``x == y``) the whole gradient goes to
    ``x`` and ``y`` receives zero. Routing it to both inputs would
    double-count it, making the two gradients sum to ``2 * gz``.

    Parameters
    ----------
    inputs : tuple
        ``(x, y)``, the two operands of the op.
    gout : tuple
        ``(gz,)``, the gradient flowing into this op's single output.

    Returns
    -------
    If the output type is in ``discrete_types``: a list of two zero tensors
    shaped like ``x`` and ``y``, cast to ``theano.config.floatX``.
    Otherwise: a tuple ``(gx, gy)``.

    Raises
    ------
    NotImplementedError
        If ``gz`` has a complex type.
    """
    (x, y) = inputs
    (gz,) = gout
    if gz.type in complex_types:
        # max is currently defined for complex_types,
        # but the gradient for complex is not.
        raise NotImplementedError()
    # Rebuild the forward result so we can tell which input was selected.
    output = self(x, y)
    if output.type in discrete_types:
        # Discrete outputs get zero gradients, cast to floatX.
        return [x.zeros_like().astype(theano.config.floatX),
                y.zeros_like().astype(theano.config.floatX)]
    # True wherever x was selected, including ties where x == y.
    x_selected = eq(output, x)
    gx = x_selected * gz
    # y receives the remainder: gz where only y was selected, 0 on ties.
    # So gx + gy == gz for finite gz, instead of 2 * gz on ties.
    gy = gz - gx
    return (gx, gy)
# (scraped web-page residue) Comment list
# (scraped web-page residue) Table of contents