def _set_batch_diagonal_cpu(array, diag_val):
    """Overwrite the main diagonal of every matrix in a batch, in place.

    Args:
        array: Array whose ``shape`` has exactly three entries
            ``(batch_size, n, n)``. It is modified in place.
        diag_val: Value(s) assigned to the diagonal entries. It is handed
            straight to the indexed assignment, so it must be assignable to
            the ``(batch_size, n)`` selection (e.g. a scalar, or an array
            of that shape).

    Returns:
        None. The result is written into ``array``.

    Raises:
        ValueError: If ``array.shape`` does not have exactly three entries
            (the shape unpacking fails).
        AssertionError: If the trailing two dimensions are not equal.
    """
    # Unpack all three dimensions so that non-3-D input is rejected here;
    # the batch size itself is not needed.
    _, m, n = array.shape
    if m != n:
        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type is
        # kept as AssertionError so existing callers see no change.
        raise AssertionError(
            "expected square matrices, got shape {}".format(array.shape)
        )
    rows, cols = np.diag_indices(n)
    # Index pairs (i, i) select the diagonal of each matrix in the batch.
    array[:, rows, cols] = diag_val