def _derivative_quadratic_form_factory(self, lhs, rhs):
    """Build a closure giving the gradients of a bilinear form w.r.t. ``lhs`` and ``rhs``.

    Write ``A = lhs``, ``B = rhs``, and give the closure factors ``L`` and ``R``.
    The closure returns ``(L^T R B^T, A^T L^T R)``. These are the gradients of
    ``trace(L A B R^T) == sum((L @ A @ B) * R)`` with respect to ``A`` and ``B``.

    ``self`` is not used.

    Args:
        lhs: Left operand of the product ``lhs @ rhs``. It must support
            ``.transpose(-1, -2)`` and ``.matmul``.
        rhs: Right operand of the product ``lhs @ rhs``.

    Returns:
        A function ``closure(left_factor, right_factor)`` returning the pair
        ``(grad_wrt_lhs, grad_wrt_rhs)``.
    """

    def closure(left_factor, right_factor):
        # Both factors must have at least 2 dims for transpose(-1, -2).
        # For the matmuls to line up:
        #   - left_factor's last dim equals lhs's row count;
        #   - right_factor's last dim equals rhs's column count;
        #   - both factors share the same second-to-last dim.
        left_t = left_factor.transpose(-1, -2)

        # L^T (R B^T): grouped so that the small (k x m) product is formed first.
        rhs_t = rhs.transpose(-1, -2)
        grad_wrt_lhs = left_t.matmul(right_factor.matmul(rhs_t))

        # (A^T L^T) R: likewise keeps the intermediate at (m x k).
        lhs_t = lhs.transpose(-1, -2)
        grad_wrt_rhs = lhs_t.matmul(left_t).matmul(right_factor)

        return grad_wrt_lhs, grad_wrt_rhs

    return closure
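# 评论列表 ("Comment list") — scraped web-page navigation text, not code.
# 文章目录 ("Table of contents") — scraped web-page navigation text, not code.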