matrix.py 文件源码

python
阅读 51 收藏 0 点赞 0 评论 0

项目:neurotools 作者: michaelerule 项目源码 文件源码
def tril_elements(M):
    '''
    Convert between a square matrix and the flat vector of its
    lower-triangular elements (diagonal included). Somewhat like matlab's
    "diag" function, but for lower triangular matrices.

    If `M` is 2D and square, the lower-triangular entries are returned as
    a 1D array, in the order given by `np.tril_indices`. If `M` is 1D
    with K = N(N+1)/2 entries, they are unpacked into the lower triangle
    of an NxN array; entries above the diagonal are zero.

    Example: tril_elements(randn(D*(D+1)//2)) returns a DxD matrix

    Parameters
    ----------
    M : np.ndarray
        Square 2D array, or 1D array whose length is a triangular number

    Returns
    -------
    np.ndarray
        1D array of N(N+1)/2 elements if `M` is an NxN matrix; NxN array
        if `M` is a vector. The unpacked matrix is at least float64 and
        is complex if `M` is complex.

    Raises
    ------
    ValueError
        If `M` is a non-square matrix, a vector whose length is not a
        triangular number, or is neither 1D nor 2D.
    '''
    ndim = len(M.shape)
    if ndim==2:
        # M is a matrix: pull out its lower-triangular elements
        rows, cols = M.shape
        if rows!=cols:
            raise ValueError("Extracting lower triangle elements supported only on square arrays")
        return M[np.tril_indices(rows)]
    if ndim==1:
        # M is a vector of K elements; find N such that N(N+1)/2 = K.
        # This is the positive root of N^2 + N - 2K = 0,
        # i.e. N = (sqrt(1+8K)-1)/2
        K = M.shape[0]
        N = int(round((np.sqrt(1+8*K)-1)/2))
        # Confirm the candidate in exact integer arithmetic rather than
        # trusting a floating-point equality test
        if N*(N+1)//2!=K:
            raise ValueError('Cannot pack %d elements into a square triangular matrix'%K)
        values = np.asarray(M)
        # Promote against float64: real and integer input still yields a
        # float64 matrix, but complex input keeps its imaginary part
        # instead of having it silently discarded on assignment
        result = np.zeros((N,N),dtype=np.result_type(values.dtype,np.float64))
        result[np.tril_indices(N)] = values
        return result
    raise ValueError("Must be 2D matrix or 1D vector")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号