def tril_elements(M):
    '''
    Convert between a square matrix and its packed lower-triangular elements.

    Somewhat like MATLAB's ``diag`` function, but for lower-triangular
    matrices:

    * Given a square 2-D array, return a 1-D array of its lower-triangular
      elements (diagonal included), in the order given by ``np.tril_indices``
      (row by row).
    * Given a 1-D array of length K = N*(N+1)/2, return an N x N array whose
      lower triangle (diagonal included) is filled from that vector in the
      same order, with zeros in the strict upper triangle.

    Example: ``tril_elements(randn(D*(D+1)//2))`` returns a D x D matrix.

    Parameters
    ----------
    M : array_like
        A square 2-D matrix, or a 1-D vector of packed lower-triangular
        elements.

    Returns
    -------
    ndarray
        A 1-D vector for matrix input. For vector input, a 2-D
        lower-triangular matrix: float64 for real input, complex128 for
        complex input.

    Raises
    ------
    ValueError
        If a 2-D input is not square, if a 1-D input's length is not a
        triangular number, or if the input is neither 1-D nor 2-D.
    '''
    import math

    # Accept plain sequences as well as arrays; asanyarray keeps ndarray
    # subclasses (e.g. np.matrix) unchanged, as before.
    M = np.asanyarray(M)
    if M.ndim == 2:
        # M is a matrix: extract its lower-triangular elements.
        if M.shape[0] != M.shape[1]:
            raise ValueError("Extracting lower triangle elements supported only on square arrays")
        return M[np.tril_indices(M.shape[0])]
    if M.ndim == 1:
        # M is a vector: solve N*(N+1)/2 == K for N, i.e.
        # N = (sqrt(1 + 8K) - 1) / 2. Integer sqrt plus an exact integer
        # check avoids float rounding errors in the triangular-number test.
        K = M.shape[0]
        N = (math.isqrt(1 + 8 * K) - 1) // 2
        if N * (N + 1) // 2 != K:
            raise ValueError('Cannot pack %d elements into a square triangular matrix' % K)
        # Real input still yields float64 (original behaviour); complex input
        # gets a complex result instead of silently losing imaginary parts.
        dtype = np.complex128 if np.iscomplexobj(M) else np.float64
        result = np.zeros((N, N), dtype=dtype)
        result[np.tril_indices(N)] = M
        return result
    raise ValueError("Must be 2D matrix or 1D vector")
# (Stray scraped web-page labels "Comment list" / "Table of contents" removed.)