def forward(self, input):
    """Apply the layer's linear transformation to ``input``.

    Computes ``F.linear(input, weight, self.bias)``, where ``weight`` is
    ``self.weight``, or a detached view of it when ``self.shared`` is truthy.

    Args:
        input: Tensor passed through to ``F.linear`` as its first argument.

    Returns:
        The result of ``F.linear(input, weight, self.bias)``.
    """
    # Detach the weight when it is shared so that gradients flowing through
    # this layer cannot change the shared weight.
    weight = self.weight.detach() if self.shared else self.weight
    return F.linear(input, weight, self.bias)