def grad(self, inputs, output_grads):
    """Return the symbolic gradients of this op with respect to its inputs.

    The gradient with respect to the signal input is ``curfft_op`` applied
    to the incoming output gradient (presumably the forward real FFT that
    is adjoint to this op -- TODO confirm). Interior frequency bins of that
    result are then doubled. The second input receives a disconnected
    gradient.

    Parameters
    ----------
    inputs : list
        ``[signal, s]``. Only ``s`` is read here, via ``s[-1]``. Presumably
        ``s`` is the symbolic transform shape and ``s[-1]`` the length of
        the transformed axis -- TODO confirm against ``make_node``.
    output_grads : list
        Single-element list holding the gradient of the cost with respect
        to this op's output.

    Returns
    -------
    list
        ``[gradient w.r.t. inputs[0], DisconnectedType()()]``.
    """
    gout, = output_grads
    s = inputs[1]
    n = s[-1]
    gf = curfft_op(gout, s)
    # Double bins 1 .. ceil(n / 2) - 1 along axis -2. All leading axes and
    # the trailing axis are taken in full (the trailing axis presumably
    # holds the real/imag pair -- TODO confirm). Those bins represent both
    # a positive and a negative frequency. Bin 0 and, for even n, bin
    # n // 2 are unique and therefore keep a factor of 1.
    idx = tuple(
        [slice(None)] * (gf.ndim - 2)
        + [slice(1, (n // 2) + (n % 2))]
        + [slice(None)]
    )
    # A tuple is the explicit multidimensional index. A bare list of slices
    # is the legacy NumPy spelling, which NumPy has deprecated.
    interior = gf[idx]
    gf = T.set_subtensor(interior, interior * 2)
    return [gf, DisconnectedType()()]
评论列表
文章目录