def forward(self, add_matrix, batch1, batch2):
    """Run the forward pass of the batched matrix-product-and-add op.

    Steps, in order:
      1. Stash ``batch1`` and ``batch2`` via ``self.save_for_backward`` so
         the backward pass can retrieve them.
      2. Obtain a destination tensor from ``self._get_output(add_matrix)``
         (defined outside this view; it may return ``add_matrix`` itself
         or a fresh buffer -- TODO confirm against the base class).
      3. Return the result of ``torch.addbmm`` written into that destination.

    Args:
        add_matrix: tensor combined with the batched products.
        batch1: first batch of matrices, saved for backward.
        batch2: second batch of matrices, saved for backward.

    Returns:
        Whatever ``torch.addbmm`` returns for the positional call below.
    """
    # Save the operands first so they are available to backward.
    self.save_for_backward(batch1, batch2)

    result = self._get_output(add_matrix)
    alpha, beta = self.alpha, self.beta

    # NOTE(review): legacy positional overload of addbmm
    # (result, alpha, add_matrix, beta, batch1, batch2). Which coefficient
    # scales which term depends on the installed torch version -- confirm;
    # modern torch does not accept this 6-positional-argument form.
    return torch.addbmm(result, alpha, add_matrix, beta, batch1, batch2)