def forward(self, add_batch, batch1, batch2):
    """Compute a batched add-matrix-multiply via ``torch.baddbmm``.

    Args:
        add_batch: batch that is added to the batched matrix product.
        batch1: left operand of the batched matrix product.
        batch2: right operand of the batched matrix product.

    Returns:
        The result of ``torch.baddbmm`` written into the buffer obtained
        from ``self._get_output(add_batch)``.
    """
    # Only the two product operands are stashed for the backward pass;
    # add_batch itself is not saved.
    self.save_for_backward(batch1, batch2)
    # NOTE(review): _get_output is defined outside this view; presumably it
    # returns add_batch itself (in-place mode) or a fresh result buffer.
    result = self._get_output(add_batch)
    # NOTE(review): legacy positional overload
    # baddbmm(result, scale_a, add, scale_b, batch1, batch2). This assumes
    # self.alpha scales add_batch and self.beta scales batch1 @ batch2,
    # which is the reverse of the modern keyword naming. Confirm against the
    # class's backward().
    scale_add = self.alpha
    scale_prod = self.beta
    return torch.baddbmm(result, scale_add, add_batch, scale_prod, batch1, batch2)