def log_normalize(log_vec, axis=0):
    """Normalize log-domain values so they sum to one along ``axis``.

    Subtracts the log-sum-exp over ``axis`` from every entry, so that
    ``np.exp(result).sum(axis=axis)`` is 1 up to floating-point error.
    Staying in log space avoids the overflow/underflow that exponentiating
    first would risk.

    Parameters
    ----------
    log_vec : numpy.ndarray
        Array of log-values, e.g. unnormalized log-probabilities.
    axis : int, optional
        Axis to normalize over; negative values count from the end.
        Default 0.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``log_vec``.

    Notes
    -----
    Relies on ``logsumexp`` from the enclosing module, which presumably is
    ``scipy.special.logsumexp``: it takes ``axis`` and returns the reduction
    with that axis removed. If every entry along ``axis`` is ``-inf``, such a
    ``logsumexp`` returns ``-inf``, and ``-inf - -inf`` makes that slice NaN.
    """
    # Reduce over ``axis``, then re-insert it as a length-1 dimension so the
    # normalizer broadcasts against ``log_vec``. The earlier approach indexed
    # with a *list* of slices/None. NumPy >= 1.23 rejects that, because
    # non-tuple sequences are no longer treated as multidimensional indices.
    log_norm = np.expand_dims(logsumexp(log_vec, axis=axis), axis)
    return log_vec - log_norm