def forward(self, dense):
    """Multiply the stored sparse operand ``self.sparse`` by ``dense``.

    Dispatch depends only on the rank of ``self.sparse``:

    * exactly 3 dimensions -> ``bdsmm(self.sparse, dense)``
      (``bdsmm`` is defined elsewhere; presumably a batched
      sparse @ dense product — TODO confirm against its definition).
    * any other rank -> ``torch.dsmm(self.sparse, dense)``.

    Args:
        dense: the dense right-hand operand, passed through unchanged.

    Returns:
        Whatever the selected multiplication routine returns.
    """
    sparse_operand = self.sparse
    # Anything that is not a 3-D stack goes straight to torch's sparse-dense mm.
    if sparse_operand.ndimension() != 3:
        return torch.dsmm(sparse_operand, dense)
    return bdsmm(sparse_operand, dense)