def _get_batch_diagonal_cpu(array): batch_size, m, n = array.shape assert m == n rows, cols = np.diag_indices(n) return array[:, rows, cols]