def gen_q_sqrt(rng, D_out, *shape):
    """Draw ``D_out`` random lower-triangular matrices, stacked on the last axis.

    Each matrix is ``np.tril`` applied to ``rng.randn(*shape)``. The matrices
    are drawn in order for output index ``0, 1, ..., D_out - 1``, so the
    result is reproducible for a seeded ``rng``.

    Args:
        rng: random source exposing ``randn`` (for example a
            ``np.random.RandomState``).
        D_out: number of matrices to draw.
        *shape: shape of each random draw. The final axis reordering requires
            the stacked array to be 3-D. In practice this means two sizes,
            e.g. ``(M, M)``.

    Returns:
        For ``shape == (M, N)``, an ndarray of shape ``(M, N, D_out)`` whose
        slice ``[:, :, d]`` is the ``d``-th lower-triangular draw.
    """
    triangles = np.stack([np.tril(rng.randn(*shape)) for _ in range(D_out)])
    # Stacked layout is (D_out, M, N); move the output index to the back.
    return triangles.transpose(1, 2, 0)