def grad(self, inputs, output_gradients):
num_ins = len(inputs)
if num_ins == 3:
x, v, sorter = inputs
else:
x, v = inputs
x_grad = gradient._float_zeros_like(x)
v_grad = gradient._float_zeros_like(v)
if num_ins == 3:
return [x_grad, v_grad, disconnected_type()]
else:
return [x_grad, v_grad]
评论列表
文章目录