def forward(self, add_matrix, matrix1, matrix2):
    """Run the fused add + matrix-multiply for this autograd function.

    Args:
        add_matrix: Matrix passed to ``torch.addmm`` as the additive term;
            also handed to ``self._get_output`` to obtain the destination.
        matrix1: Left operand of the product; saved for the backward pass.
        matrix2: Right operand of the product; saved for the backward pass.

    Returns:
        The value returned by ``torch.addmm`` when writing into the buffer
        produced by ``self._get_output(add_matrix)``.

    NOTE(review): ``torch.addmm`` is called with the legacy positional
    form ``(out, alpha, add_matrix, beta, matrix1, matrix2)``. Presumably
    ``self.alpha`` scales ``add_matrix`` and ``self.beta`` scales the
    ``matrix1 @ matrix2`` product -- confirm against the torch version this
    file targets; current ``torch.addmm`` may not accept this positional form.
    """
    # Only the two product operands are stashed; add_matrix is not saved
    # (presumably because its gradient does not depend on its value).
    self.save_for_backward(matrix1, matrix2)

    # The destination buffer comes from a helper defined outside this view
    # (presumably add_matrix itself when in-place, else a fresh tensor --
    # confirm against the base class).
    dest = self._get_output(add_matrix)

    alpha, beta = self.alpha, self.beta
    return torch.addmm(
        dest,
        alpha,
        add_matrix,
        beta,
        matrix1,
        matrix2,
    )