def backward(self, grad_output):
    """Multiply the transposed stored sparse matrix by ``grad_output``.

    The sparse operand is read from ``self.sparse``. Only one value is
    returned. Presumably this is the gradient for the dense input of a
    forward pass that computed ``sparse @ dense``; confirm against the
    matching ``forward``.

    Args:
        grad_output: Incoming gradient. It is multiplied on the left by
            the transposed sparse matrix.

    Returns:
        For a 3-D ``self.sparse``: ``bdsmm(sparse.transpose(1, 2), grad_output)``.
        For any other rank: ``torch.dsmm(sparse.t(), grad_output)``.
    """
    sparse_mat = self.sparse

    # Non-batched case: plain matrix transpose via .t().
    if sparse_mat.ndimension() != 3:
        return torch.dsmm(sparse_mat.t(), grad_output)

    # 3-D case: swap the last two dims and use the batched helper
    # (presumably a batch of matrices with the batch on dim 0; verify).
    transposed = sparse_mat.transpose(1, 2)
    return bdsmm(transposed, grad_output)