def grad(self, inputs, output_gradients):
    """Return one gradient expression per entry of ``inputs``.

    Parameters
    ----------
    inputs
        The inputs of this op. On the integer-dtype path each input must
        provide ``zeros_like()`` whose result supports ``astype``.
    output_gradients
        Sequence whose first element is the gradient with respect to this
        op's output. Only that first element is read, and it is indexed by
        input position.

    Returns
    -------
    list
        A list with the same length as ``inputs``. When ``self.dtype``
        contains the substring ``'int'``, every entry is a zero value shaped
        like the matching input and cast to ``theano.config.floatX``.
        Otherwise entry ``i`` is ``output_gradients[0][i]``.
    """
    # If the output is of an integer dtype, no gradient shall pass
    if 'int' in self.dtype:
        return [ipt.zeros_like().astype(theano.config.floatX)
                for ipt in inputs]
    # NOTE(review): presumably element i of the output is derived from
    # input i, which is why the output gradient is split by position --
    # confirm against the op's perform/make_node.
    # Only the positions of the inputs matter here, not their values.
    return [output_gradients[0][i] for i, _ in enumerate(inputs)]
评论列表
文章目录