def reset_parameters(self):
    """Re-initialise ``weight`` and ``bias`` in place from U(-b, b).

    The bound is ``b = sqrt(3 / in_features)``. A uniform distribution on
    [-b, b] has variance b**2 / 3, so each entry has variance
    ``1 / in_features``, i.e. standard deviation ``1 / sqrt(in_features)``.

    ``bias`` is skipped when it is ``None`` (e.g. a layer built without a
    bias term), instead of crashing inside the initialiser.
    """
    # Half-width of the uniform range (not the std; see docstring).
    bound = math.sqrt(3 / self.in_features)
    # ``uniform_`` is the supported in-place initialiser; the un-suffixed
    # ``nn.init.uniform`` is deprecated.
    nn.init.uniform_(self.weight, -bound, bound)
    if self.bias is not None:
        nn.init.uniform_(self.bias, -bound, bound)