preprocess.py source code

python

Project: plda   Author: RaviSoji
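import numpy as np
from sklearn.decomposition import PCA


def get_principal_components(flattened_images, n_components='default',
                             default_pct_variance_explained=.96):
    """ Standardizes the data and projects it onto its principal components.

    flattened_images: sequence of 1-D np.ndarrays that all have the same shape.
    n_components: number of components to keep, or 'default' to keep the
        fewest components whose cumulative share of the variance is at least
        default_pct_variance_explained.
    Returns an array of shape (len(flattened_images), n_components).
    """
    for img in flattened_images:
        assert isinstance(img, np.ndarray)
        assert img.shape == flattened_images[-1].shape
        assert len(img.shape) == 1
    # Cast to float so the in-place arithmetic below also works for integer
    # (e.g. uint8) pixel data.
    X = np.asarray(flattened_images, dtype=float)
    X -= X.mean(axis=0)  # Center all of the data around the origin.
    std = np.std(X, axis=0)
    std[std == 0] = 1.  # Constant pixels stay 0 instead of becoming NaN.
    X /= std  # Scale every pixel to unit variance.

    pca = PCA()
    pca.fit(X)

    if n_components == 'default':
        sorted_eig_vals = pca.explained_variance_
        cum_pct_variance = (sorted_eig_vals / sorted_eig_vals.sum()).cumsum()
        reached = cum_pct_variance >= default_pct_variance_explained
        # np.argmax gives the index of the first True entry; + 1 turns that
        # index into a count. If round-off keeps the threshold from being
        # reached, keep every component.
        n_components = (int(np.argmax(reached)) + 1 if reached.any()
                        else len(reached))

    # Rows of components_ are sorted by decreasing explained variance, so the
    # first n_components rows are the top principal axes.
    V = pca.components_[:n_components, :].T
    principal_components = np.matmul(X, V)  # Project onto those axes.

    return principal_components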
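The centering and scaling steps at the start of the function can be checked on their own. Below is a minimal NumPy-only sketch; `standardize` is a hypothetical helper name, and leaving zero-variance columns at 0 (rather than dividing by a zero standard deviation) is an assumption of this sketch. The toy "images" are made up for illustration.

```python
import numpy as np


def standardize(rows):
    """Hypothetical helper: center each column and scale it to unit variance.

    Columns with zero variance (e.g. a pixel that is identical in every
    image) are only centered, so they become all zeros instead of NaN.
    """
    X = np.asarray(rows, dtype=float)  # float copy, so uint8 input is safe
    X = X - X.mean(axis=0)
    std = X.std(axis=0)
    std[std == 0] = 1.0                # assumption: don't rescale constant columns
    return X / std


# Three made-up 3-pixel "images"; the last pixel is constant.
images = [np.array([0, 10, 255], dtype=np.uint8),
          np.array([2, 20, 255], dtype=np.uint8),
          np.array([4, 30, 255], dtype=np.uint8)]
Z = standardize(images)
print(Z.mean(axis=0))  # approximately [0, 0, 0]
print(Z.std(axis=0))   # approximately [1, 1, 0]
```

Without the cast to float, the in-place `X -= X.mean(axis=0)` on a `uint8` array raises a casting error in NumPy, and without the zero-variance guard the constant pixel turns into NaN.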
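The `'default'` branch picks the number of components from the cumulative fraction of explained variance. The sketch below isolates that rule with NumPy; `n_components_for_variance` is a hypothetical helper, and the eigenvalues are invented so the cumulative fractions are easy to read (0.5, 0.8, 0.95, 0.99, 1.0).

```python
import numpy as np


def n_components_for_variance(explained_variance, threshold):
    """Hypothetical helper: smallest k such that the top-k components explain
    at least `threshold` of the total variance.

    `explained_variance` is assumed to be sorted in decreasing order, as
    sklearn's PCA.explained_variance_ is.
    """
    ev = np.asarray(explained_variance, dtype=float)
    cum_frac = np.cumsum(ev / ev.sum())
    reached = cum_frac >= threshold
    if not reached.any():                 # threshold never met: keep everything
        return len(ev)
    return int(np.argmax(reached)) + 1    # first True index -> component count


eigvals = [5.0, 3.0, 1.5, 0.4, 0.1]      # made-up, already sorted
print(n_components_for_variance(eigvals, 0.96))  # 4
print(n_components_for_variance(eigvals, 0.75))  # 2
```

Returning a count rather than an index keeps the meaning of `n_components` the same whether it is chosen automatically or passed in by the caller.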
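scikit-learn's `PCA` obtains `components_` from a singular value decomposition of the centered data, so multiplying the data by the first k rows of `components_` (transposed) is the same as taking the first k columns of U scaled by the singular values. The sketch below checks that identity with `np.linalg.svd` on synthetic random data instead of calling scikit-learn; `Vt` plays the role of `pca.components_` here, and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))     # synthetic data: 50 "images", 8 pixels each
X -= X.mean(axis=0)              # PCA assumes centered data
X /= X.std(axis=0)

# Rows of Vt are the principal axes, sorted by decreasing singular value.
U, S, Vt = np.linalg.svd(X, full_matrices=False)

k = 3
V = Vt[:k, :].T                  # shape (8, k): the top-k axes as columns
scores = np.matmul(X, V)         # projection onto the top-k axes

# Since X = U @ diag(S) @ Vt, projecting onto Vt[:k] gives U[:, :k] * S[:k].
print(np.allclose(scores, U[:, :k] * S[:k]))
# The variance of each score column equals the PCA eigenvalue S**2 / (n - 1).
print(np.var(scores, axis=0, ddof=1))
print(S[:k] ** 2 / (X.shape[0] - 1))
```

The eigenvalues `S**2 / (n - 1)` are what scikit-learn reports as `explained_variance_`, which is why the cumulative-variance rule can be computed directly from that attribute.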