def _derivative_quadratic_form_factory(self, *args):
    """Build a closure that evaluates the derivative quadratic form for ``args``.

    The layout of ``args`` selects the behaviour of the returned closure:

    * ``(columns,)``
    * ``(columns, W_left, W_right)``
    * ``(columns, W_left, W_right, added_diag)``

    Returns:
        A function ``closure(left_vectors, right_vectors)`` that returns a
        tuple with one entry per element of ``args``. The first entry is the
        result of ``kp_sym_toeplitz_derivative_quadratic_form``; the entries
        for ``W_left`` and ``W_right`` are always ``None``; the entry for
        ``added_diag`` (when present) is a tensor of ``len(added_diag)``
        values, zero everywhere except index 0.

    Raises:
        ValueError: raised by the closure when it is called, if ``args`` does
            not have 1, 3 or 4 elements. (Previously the closure silently
            returned ``None`` in that case.)
    """
    def closure(left_vectors, right_vectors):
        num_args = len(args)
        if num_args not in (1, 3, 4):
            raise ValueError(
                "Expected 1, 3 or 4 representation arguments "
                "(columns[, W_left, W_right[, added_diag]]), got %d" % num_args
            )

        # A single vector is promoted to a one-row matrix so that the rest of
        # the code only has to deal with 2-D factors.
        if left_vectors.ndimension() == 1:
            left_factor = left_vectors.unsqueeze(0)
            right_factor = right_vectors.unsqueeze(0)
        else:
            left_factor = left_vectors
            right_factor = right_vectors

        if num_args == 1:
            columns, = args
            # Explicit 1-tuple: one entry per element of ``args``.
            return (kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),)

        columns, W_left, W_right = args[:3]

        diag_grad = None
        if num_args == 4:
            added_diag = args[3]
            # Must be computed from the factors *before* they are projected
            # through W_left / W_right below. Only index 0 is populated.
            diag_grad = columns.new(len(added_diag)).zero_()
            diag_grad[0] = (left_factor * right_factor).sum()

        # Project both factors through the transposed W matrices
        # (presumably sparse interpolation matrices, given torch.dsmm -- confirm).
        left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
        right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
        res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)

        if num_args == 4:
            return res, None, None, diag_grad
        return res, None, None

    return closure
kronecker_product_lazy_variable.py 文件源码
python
阅读 42
收藏 0
点赞 0
评论 0
评论列表
文章目录