def backward(ctx, grad_output):
    """Compute the gradient of the factorisation with respect to its input.

    Reads the factor saved by the forward pass and the ``upper`` flag from
    ``ctx``, then maps ``grad_output`` (gradient w.r.t. the returned factor)
    to a gradient w.r.t. the first forward input.

    Args:
        ctx: autograd context; must expose ``saved_variables`` (holding
            exactly one saved matrix) and a boolean ``upper`` attribute.
        grad_output: gradient of the loss with respect to the factor
            returned by the forward pass (2-D, since ``.t()`` is applied).

    Returns:
        A pair ``(S, None)``. ``S`` is the gradient for the first forward
        input; the second slot is ``None`` -- presumably for a
        non-differentiable ``upper`` argument, confirm against ``forward``.
    """
    # NOTE(review): ``saved_variables`` and ``torch.gesv`` are legacy PyTorch
    # names; kept as-is because the targeted PyTorch version is not visible
    # from here.
    L, = ctx.saved_variables

    if ctx.upper:
        # Transpose both the saved factor and the incoming gradient so the
        # rest of the computation uses a single orientation (assumes the
        # saved factor is upper-triangular when ``ctx.upper`` is set --
        # TODO confirm against ``forward``).
        L = L.t()
        grad_output = grad_output.t()

    # make sure not to double-count variation, since
    # only half of output matrix is unique
    Lbar = grad_output.tril()

    # ``Potrf.phi`` is defined outside this block; see its own definition
    # for the exact masking/scaling it applies.
    P = Potrf.phi(torch.mm(L.t(), Lbar))

    # ``torch.gesv(B, A)[0]`` solves ``A X = B``. The first solve yields
    # L^{-T} (P + P^T); transposing that and solving against L^T again
    # yields L^{-T} (P + P^T) L^{-1}.
    S = torch.gesv(P + P.t(), L.t())[0]
    S = torch.gesv(S.t(), L.t())[0]
    S = Potrf.phi(S)

    return S, None
评论列表
文章目录