kde.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:ChainConsumer 作者: Samreay 项目源码 文件源码
def __init__(self, train, weights=None, truncation=3.0, nmin=4, factor=1.0):
        """
        Whiten the training samples, index them in a kD tree and set the starting bandwidth.

        Parameters
        ----------
        train : np.ndarray
            Training samples: either a 1D array of samples or a 2D array of shape (n_samples, n_dim).
        weights : np.ndarray, optional
            Per-sample weights. When omitted, every sample gets unit weight.
        truncation : float, optional
            The maximum deviation (in sigma) to use points in the KDE.
        nmin : int, optional
            The minimum number of points required to estimate the density.
        factor : float, optional
            Multiplier applied to the data-derived starting bandwidth.
        """

        self.truncation = truncation
        self.nmin = nmin
        self.train = train  # stored exactly as passed in, before any reshaping below

        # A flat vector of samples is promoted to a single-column (n_samples, 1) matrix.
        samples = train
        if len(samples.shape) == 1:
            samples = np.atleast_2d(samples).T
        self.num_points, self.num_dim = samples.shape

        sample_weights = np.ones(self.num_points) if weights is None else weights
        self.weights = sample_weights

        # Weighted first and second moments of the training set.
        self.mean = np.average(samples, weights=sample_weights, axis=0)
        centered = samples - self.mean
        covariance = np.atleast_2d(np.cov(centered.T, aweights=sample_weights))

        # Sphere-ifying transform: Cholesky factor of the inverse covariance.
        precision = np.linalg.inv(covariance)
        self.A = np.linalg.cholesky(precision)

        # Sphere-ified data, plus a kD tree built over it.
        self.d = np.dot(centered, self.A)
        self.tree = spatial.cKDTree(self.d)

        # Starting sigma (bandwidth) of the Gaussian kernel: 2 * factor * n ** (-1 / (4 + dim)).
        exponent = -1. / (4 + self.num_dim)
        bandwidth = 2.0 * factor * np.power(self.num_points, exponent)
        self.sigma = bandwidth
        # Precomputed -1 / (2 * sigma^2).
        self.sigma_fact = -0.5 / (bandwidth * bandwidth)

        # Normalised probabilities are not working yet, so they stay disabled; contours do not
        # need a normalised pdf.
        # self.norm = np.product(np.diagonal(self.A)) * (2 * np.pi) ** (-0.5 * self.num_dim)  # prob norm
        # self.scaling = np.power(self.norm * self.sigma, -self.num_dim)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号