def dirac_delta(ni, no, k):
    """Return a Dirac-initialised tensor, tiled along its first two axes.

    A square ``(side, side, *k)`` tensor is created, where ``side`` is the
    smaller of ``ni`` and ``no``. That tensor is passed through ``dirac``
    (resolved outside this function; presumably ``torch.nn.init.dirac`` --
    confirm at the import site). The result is then repeated so that the
    larger of the two channel counts is covered.

    Args:
        ni: Input channel count (an int, judging by the ``//`` arithmetic).
        no: Output channel count (an int, judging by the ``//`` arithmetic).
        k: Trailing dimensions appended after the two channel axes. It must
            support ``tuple + k``, so in practice it is a tuple, presumably
            the kernel's spatial size.

    Returns:
        The tensor returned by ``dirac(...).repeat(...)``. Axis 0 is tiled
        ``max(no // ni, 1)`` times and axis 1 ``max(ni // no, 1)`` times.

    NOTE(review): when ``ni`` and ``no`` are not multiples of each other, the
    result is not ``(no, ni, *k)``. For example, ``ni=3, no=5`` yields
    ``(3, 3, *k)``. Confirm that callers only use divisible channel counts.
    """
    side = min(ni, no)
    shape = (side, side) + k

    # Only one of these can exceed 1; the other stays at 1.
    out_tiles = max(no // ni, 1)
    in_tiles = max(ni // no, 1)
    # Trailing (spatial) axes are never tiled.
    tail_tiles = (1,) * len(k)

    # torch.Tensor(*shape) allocates without initialising the values;
    # dirac is expected to overwrite every element.
    seed = torch.Tensor(*shape)
    return dirac(seed).repeat(out_tiles, in_tiles, *tail_tiles)