def _set_batch_non_diagonal_cpu(array, non_diag_val):
    """Write values into the strictly-lower triangle of each matrix in a batch.

    The assignment is done in place on ``array``; the diagonal and the
    upper triangle are left untouched.

    Args:
        array: NumPy array with exactly three dimensions whose last two
            dimensions are equal, i.e. a batch of square matrices.
        non_diag_val: Value(s) assigned to ``array[:, rows, cols]`` where
            ``rows, cols = np.tril_indices(n, -1)``. Anything NumPy can
            broadcast to shape ``(batch, n * (n - 1) // 2)`` is accepted,
            e.g. a scalar or one row of values per batch entry. Values
            are laid out in the order ``np.tril_indices`` yields indices.

    Returns:
        None. ``array`` is modified in place.

    Raises:
        ValueError: If ``array`` is not 3-dimensional (tuple unpacking
            fails) or its last two dimensions differ.
    """
    # The three-way unpack doubles as a rank check; the batch size itself
    # is not needed because the leading axis is addressed with a slice.
    _, m, n = array.shape
    # Explicit raise rather than ``assert`` so the check survives
    # ``python -O``; without it a tall (m > n) input would be silently
    # written with the wrong index set.
    if m != n:
        raise ValueError(
            'array must be a batch of square matrices, '
            'got trailing shape (%d, %d)' % (m, n))
    # k=-1 excludes the main diagonal, leaving only entries below it.
    rows, cols = np.tril_indices(n, -1)
    array[:, rows, cols] = non_diag_val