def forward(ctx, input, diagonal_idx=0):
    """Take (or build) the diagonal of ``input`` at offset ``diagonal_idx``.

    The offset is first recorded on ``ctx``; the value then read back from
    ``ctx.diagonal_idx`` is passed to ``input.diag``.

    NOTE(review): ``ctx``/``forward`` suggest a ``torch.autograd.Function``.
    In that case the stored offset is presumably consumed by a matching
    ``backward`` that is not visible here. Also assumes ``input`` is a
    ``torch.Tensor`` — TODO confirm.

    Args:
        ctx: Context object; receives a ``diagonal_idx`` attribute.
        input: Object exposing a ``diag(offset)`` method.
        diagonal_idx: Diagonal offset forwarded to ``input.diag``
            (defaults to 0).

    Returns:
        Whatever ``input.diag(ctx.diagonal_idx)`` returns.
    """
    # Record the offset on the context first, then read it back from there,
    # so the call below always uses exactly what ctx now holds.
    ctx.diagonal_idx = diagonal_idx
    offset = ctx.diagonal_idx
    return input.diag(offset)