def forward(self, input):
    """Run the linear projection and hand the result to ``sparsify_grad``.

    Args:
        input: Value forwarded unchanged as the first argument of
            ``F.linear``.

    Returns:
        Whatever ``sparsify_grad`` returns for the projected value together
        with this module's ``k`` and ``simplified`` settings.
    """
    # Affine step using this module's own weight and bias.
    projected = F.linear(input, self.weight, self.bias)

    # NOTE(review): judging by its name, sparsify_grad presumably leaves the
    # forward values intact and only sparsifies the gradient on the backward
    # pass -- confirm against its definition.
    return sparsify_grad(projected, self.k, self.simplified)