def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    r"""
    Multiply an (optionally interpolated) Kronecker product of symmetric Toeplitz matrices by ``tensor``.

    Computes ``(W_l (T_1 \otimes ... \otimes T_d) W_r^T + diag(s)) tensor``, where
    ``W_l = interp_left``, ``W_r = interp_right`` and ``s = noise_diag``. Each of the three
    optional pieces is applied only when it is given:

    - no ``interp_right``: the ``W_r^T`` projection is skipped;
    - no ``interp_left``: the ``W_l`` projection is skipped;
    - no ``noise_diag``: no diagonal term is added.

    Args:
        - toeplitz_columns - columns of the d Toeplitz factors T_i. They are passed as both the
          column and the row argument of ``kronecker_product_toeplitz_matmul``, which is what
          makes each T_i symmetric. (Exact expected layout: see that helper.)
        - tensor (matrix p x k, or vector of length p) - vector (k=1) or matrix (k>1) to multiply.
        - interp_left (sparse matrix n x m, optional) - left interpolation matrix W_l.
        - interp_right (sparse matrix p x m, optional) - right interpolation matrix W_r.
        - noise_diag (tensor of length p, optional) - if given, ``diag(noise_diag) @ tensor`` is
          added to the result. The term is built from the *input* tensor, so the result must
          also have p rows (e.g. n == p when both interpolation matrices are used).

    Returns:
        - tensor - the product. If ``tensor`` was 1-D, the result is squeezed back to 1-D.
    """
    output_dims = tensor.ndimension()
    if output_dims == 1:
        # Promote a vector to a single-column matrix so every product below is matrix-matrix.
        tensor = tensor.unsqueeze(1)

    # diag(s) tensor is formed from the untransformed input, before any projection.
    noise_term = None
    if noise_diag is not None:
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    # Apply each interpolation matrix independently. Previously the interp_right projection
    # was guarded by `interp_left is not None`, which crashed or silently dropped interp_right
    # when only one of the two was supplied.
    rhs = tensor
    if interp_right is not None:
        # W_r^T tensor
        rhs = torch.dsmm(interp_right.t(), rhs)

    # K rhs, with K = T_1 \otimes ... \otimes T_d (columns passed twice => symmetric factors)
    output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, rhs)

    if interp_left is not None:
        # W_l K W_r^T tensor
        output = torch.dsmm(interp_left, output)

    if noise_term is not None:
        # (W_l K W_r^T + diag(s)) tensor
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)
    return output
评论列表
文章目录