def grad(self, inputs, gout):
    """Return the gradients of ``switch(cond, ift, iff)`` w.r.t. its inputs.

    Parameters
    ----------
    inputs : tuple
        ``(cond, ift, iff)``, the three inputs of this op.
    gout : tuple
        One-element tuple ``(gz,)`` holding the gradient flowing into this
        op's output.

    Returns
    -------
    tuple
        ``(condition_grad, first_part, second_part)``:

        * ``condition_grad`` -- ``cond.zeros_like()`` cast to
          ``theano.config.floatX``.
        * ``first_part`` -- gradient w.r.t. ``ift``: ``gz`` where ``cond``
          holds and ``0.`` elsewhere.
        * ``second_part`` -- gradient w.r.t. ``iff``: ``0.`` where ``cond``
          holds and ``gz`` elsewhere.

        When the op's output type is discrete, both branch gradients are
        the constant ``0.``.
    """
    (cond, ift, iff) = inputs
    (gz,) = gout

    # Route the incoming gradient to whichever branch was selected.
    first_part = switch(cond, gz, 0.)
    second_part = switch(cond, 0., gz)

    # Rebuild the output only to inspect its type.
    out_type = self(cond, ift, iff).type
    # NOTE(review): if ``discrete_types`` holds type objects (as in
    # theano.scalar) rather than dtype strings, the original test
    # ``out.type.dtype in discrete_types`` could never match, so discrete
    # outputs silently received a non-zero gradient. Test the type itself,
    # and keep the dtype-name test in case this module stores strings.
    if out_type in discrete_types or out_type.dtype in discrete_types:
        first_part = 0.
        second_part = 0.

    # cond does affect the elements of the output so it is connected.
    # For the sake of making the gradient convenient we assume that
    # condition + epsilon always triggers the same branch as condition.
    condition_grad = cond.zeros_like().astype(theano.config.floatX)

    return (condition_grad, first_part, second_part)
评论列表
文章目录