def test_dsmm(self):
    """Check ``torch.dsmm`` against ``torch.mm`` on the densified operand."""

    def _check(di, dj, dk):
        # First element returned by the sparse generator is the sparse operand.
        # NOTE(review): the (2, 20, [di, dj]) arguments are presumably
        # sparse-dim count, nnz and size -- confirm against ``_gen_sparse``.
        sparse_lhs = self._gen_sparse(2, 20, [di, dj])[0]
        dense_rhs = self.randn(dj, dk)
        self.assertEqual(
            torch.dsmm(sparse_lhs, dense_rhs),
            torch.mm(sparse_lhs.to_dense(), dense_rhs),
        )

    # (di, dj, dk) triples exercised by the test, in the original order.
    for dims in ((7, 5, 3), (1000, 100, 100), (3000, 64, 300)):
        _check(*dims)
# Example source code for Python's dsmm()
def test_dsmm(self):
    """Check ``torch.dsmm`` against ``torch.mm`` on the densified operand."""

    def _check(di, dj, dk):
        # First element returned by the sparse generator is the sparse operand.
        # NOTE(review): the (2, 20, [di, dj]) arguments are presumably
        # sparse-dim count, nnz and size -- confirm against ``_gen_sparse``.
        sparse_lhs = self._gen_sparse(2, 20, [di, dj])[0]
        dense_rhs = self.randn(dj, dk)
        self.assertEqual(
            torch.dsmm(sparse_lhs, dense_rhs),
            torch.mm(sparse_lhs.to_dense(), dense_rhs),
        )

    # (di, dj, dk) triples exercised by the test, in the original order.
    for dims in ((7, 5, 3), (1000, 100, 100), (3000, 64, 300)):
        _check(*dims)
def test_dsmm(self):
    """Check ``torch.dsmm`` against ``torch.mm`` on the densified operand."""

    def _check(di, dj, dk):
        # First element returned by the sparse generator is the sparse operand.
        # NOTE(review): the (2, 20, [di, dj]) arguments are presumably
        # sparse-dim count, nnz and size -- confirm against ``_gen_sparse``.
        sparse_lhs = self._gen_sparse(2, 20, [di, dj])[0]
        dense_rhs = self.randn(dj, dk)
        self.assertEqual(
            torch.dsmm(sparse_lhs, dense_rhs),
            torch.mm(sparse_lhs.to_dense(), dense_rhs),
        )

    # (di, dj, dk) triples exercised by the test, in the original order.
    for dims in ((7, 5, 3), (1000, 100, 100), (3000, 64, 300)):
        _check(*dims)
def test_interpolation():
    """Interpolating x**2 from grid values must match x**2 at the target points.

    The interpolation indices/coefficients are converted to a sparse matrix,
    applied to the squared grid via ``torch.dsmm``, and the result is compared
    to the squared targets with a relative tolerance of 1e-5.
    """
    targets = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)

    indices, coefs = Interpolation().interpolate(grid, targets)
    interp_matrix = utils.toeplitz.index_coef_to_sparse(indices, coefs, len(grid))

    grid_values = grid.pow(2)
    expected = targets.pow(2)
    # dsmm needs a 2-D right operand, so lift the grid values to a column
    # and flatten the product back afterwards.
    interpolated = torch.dsmm(interp_matrix, grid_values.unsqueeze(1)).squeeze()

    # The 1e-10 offset keeps the relative error finite for tiny targets.
    relative_error = torch.abs(interpolated - expected) / (expected + 1e-10)
    assert all(relative_error < 1e-5)
def _derivative_quadratic_form_factory(self, *args):
def closure(left_vectors, right_vectors):
if left_vectors.ndimension() == 1:
left_factor = left_vectors.unsqueeze(0)
right_factor = right_vectors.unsqueeze(0)
else:
left_factor = left_vectors
right_factor = right_vectors
if len(args) == 1:
columns, = args
return kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),
elif len(args) == 3:
columns, W_left, W_right = args
left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
return res, None, None
elif len(args) == 4:
columns, W_left, W_right, added_diag, = args
diag_grad = columns.new(len(added_diag)).zero_()
diag_grad[0] = (left_factor * right_factor).sum()
left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
return res, None, None, diag_grad
return closure
def forward(self, dense):
    """Multiply the stored sparse tensor by ``dense``.

    A 3-D ``self.sparse`` is dispatched to ``bdsmm``; anything else goes
    through ``torch.dsmm``.
    """
    is_three_dim = self.sparse.ndimension() == 3
    if not is_three_dim:
        return torch.dsmm(self.sparse, dense)
    return bdsmm(self.sparse, dense)
def backward(self, grad_output):
    """Multiply the transposed stored sparse tensor by ``grad_output``.

    A 3-D ``self.sparse`` has its last two dims swapped and is dispatched to
    ``bdsmm``; anything else is transposed with ``.t()`` and goes through
    ``torch.dsmm``.
    """
    is_three_dim = self.sparse.ndimension() == 3
    if not is_three_dim:
        return torch.dsmm(self.sparse.t(), grad_output)
    return bdsmm(self.sparse.transpose(1, 2), grad_output)
def test_dsmm(self):
    """Check ``torch.dsmm`` against ``torch.mm`` on the densified operand."""

    def _check(di, dj, dk):
        # First element returned by the sparse generator is the sparse operand.
        # NOTE(review): the (2, 20, [di, dj]) arguments are presumably
        # sparse-dim count, nnz and size -- confirm against ``_gen_sparse``.
        sparse_lhs = self._gen_sparse(2, 20, [di, dj])[0]
        dense_rhs = self.randn(dj, dk)
        # The dense reference goes through the test-case's own densifying
        # helper rather than calling ``to_dense`` directly.
        self.assertEqual(
            torch.dsmm(sparse_lhs, dense_rhs),
            torch.mm(self.safeToDense(sparse_lhs), dense_rhs),
        )

    # (di, dj, dk) triples exercised by the test, in the original order.
    for dims in ((7, 5, 3), (1000, 100, 100), (3000, 64, 300)):
        _check(*dims)
def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    r"""
    Multiply an interpolated Kronecker-product Toeplitz matrix with a tensor.

    Given an interpolated matrix
    interp_left * (T_1 \otimes ... \otimes T_d) * interp_right^T, plus possibly
    an additional diagonal component, compute its product with a vector or
    matrix, where each T_i is a symmetric Toeplitz matrix.

    Args:
        - toeplitz_columns (d x m matrix) - columns of the d Toeplitz matrices
          T_i; passed as both the second and first argument of
          ``kronecker_product_toeplitz_matmul``.
        - tensor (vector p or matrix p x k) - Vector or matrix to multiply
          with. A 1-D input is treated as a single column and the result is
          returned 1-D again.
        - interp_left (sparse matrix n x m) - Left interpolation matrix. When
          None, no interpolation is applied at all and ``interp_right`` is
          ignored.
        - interp_right (sparse matrix p x m) - Right interpolation matrix.
          Required whenever ``interp_left`` is given.
        - noise_diag (tensor p) - If not None, each row of ``tensor`` is
          scaled by the matching entry of ``noise_diag`` and the result is
          added to the product at the end.
    Returns:
        - tensor with the same number of dimensions as the input ``tensor``
    Raises:
        - ValueError if ``interp_left`` is given without ``interp_right``.
    """
    # Fail early with a clear message instead of an AttributeError from
    # calling ``.t()`` on None further down.
    if interp_left is not None and interp_right is None:
        raise ValueError("interp_right must be provided when interp_left is given")

    output_dims = tensor.ndimension()
    if output_dims == 1:
        # Work on a single-column matrix; undone before returning.
        tensor = tensor.unsqueeze(1)

    noise_term = None
    if noise_diag is not None:
        # Diagonal times tensor: broadcast the diagonal across the columns.
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    if interp_left is not None:
        # Get interp_{r}^{T} tensor
        interp_right_tensor = torch.dsmm(interp_right.t(), tensor)
        # Get (T interp_{r}^{T}) tensor
        rhs = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, interp_right_tensor)
        # Get (interp_{l} T interp_{r}^{T}) tensor
        output = torch.dsmm(interp_left, rhs)
    else:
        output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, tensor)

    if noise_term is not None:
        # Get (interp_{l} T interp_{r}^{T} + diag) tensor
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)
    return output