def inv_matmul_factory(matmul_closure_factory=_default_matmul_closure_factor,
                       derivative_quadratic_form_factory=_default_derivative_quadratic_form_factory):
    """Build an autograd ``Function`` class that solves a linear system via ``LinearCG``.

    Args:
        matmul_closure_factory: Callable invoked as
            ``matmul_closure_factory(*closure_args)``; whatever it returns is
            handed to ``LinearCG().solve`` as the first argument.
        derivative_quadratic_form_factory: Callable invoked as
            ``derivative_quadratic_form_factory(*args)``; the callable it
            returns is applied to two transposed matrices and must yield an
            iterable with the gradients of the non-rhs inputs. May be
            ``None``, in which case ``backward`` raises
            ``NotImplementedError``.

    Returns:
        The ``InvMatmul`` class itself (not an instance).
    """

    class InvMatmul(Function):
        """Autograd function whose forward pass is a ``LinearCG`` solve."""

        def __init__(self, *args):
            # Extra closure arguments that are prepended to the tensor inputs
            # in both forward and backward.
            self.args = args

        def forward(self, *args):
            """Solve against ``args[-1]`` (the rhs) using the remaining inputs.

            All positional inputs except the last are combined with
            ``self.args`` and passed to ``matmul_closure_factory``.
            """
            closure_args = self.args + args[:-1]
            rhs = args[-1]
            res = LinearCG().solve(matmul_closure_factory(*closure_args), rhs)
            # Keep the solution next to the inputs so backward can reuse it.
            self.save_for_backward(*(list(args) + [res]))
            return res

        def backward(self, grad_output):
            """Return one gradient per forward input (rhs gradient last).

            Raises:
                NotImplementedError: if no ``derivative_quadratic_form_factory``
                    was supplied to the enclosing factory.
            """
            if derivative_quadratic_form_factory is None:
                raise NotImplementedError(
                    "backward requires a derivative_quadratic_form_factory"
                )

            # Saved layout is (*forward_inputs, res) and the last forward
            # input is the rhs, so dropping the final two entries leaves only
            # the non-rhs inputs.
            args = self.saved_tensors[:-2]
            closure_args = self.args + args
            res = self.saved_tensors[-1]

            arg_grads = [None] * len(args)
            rhs_grad = None

            needs_arg_grads = any(self.needs_input_grad[:-1])
            needs_rhs_grad = self.needs_input_grad[-1]

            if needs_arg_grads or needs_rhs_grad:
                # Both gradients are built from the very same solve against
                # grad_output, so run it once and share the result instead of
                # solving twice.
                grad_solve = LinearCG().solve(matmul_closure_factory(*closure_args), grad_output)

                # input_1 gradient
                if needs_arg_grads:
                    # Out-of-place negation: grad_solve may also be returned
                    # unchanged as the rhs gradient below.
                    lhs_matrix_grad = grad_solve.mul(-1)
                    if res.ndimension() == 1:
                        res = res.unsqueeze(1)
                    if lhs_matrix_grad.ndimension() == 1:
                        lhs_matrix_grad = lhs_matrix_grad.unsqueeze(1)
                    arg_grads = list(
                        derivative_quadratic_form_factory(*args)(lhs_matrix_grad.t(), res.t())
                    )

                # input_2 gradient
                if needs_rhs_grad:
                    rhs_grad = grad_solve

            return tuple(arg_grads + [rhs_grad])

    return InvMatmul
# [web-page residue] Comment list
# [web-page residue] Table of contents