mv_gaussian.py source code

python

Project: motif  Author: rabitt
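# Imports this excerpt relies on (assumed; they are not shown in the original listing).
import numpy as np
from scipy.stats import multivariate_normal

def fit(self, X, Y):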
        """ Fit class-dependent multivariate gaussians on the training set.

        The features are Box-Cox transformed internally before fitting.

        Parameters
        ----------
        X : np.array [n_samples, n_features]
            Training features.
        Y : np.array [n_samples]
            Training labels (1 for melody, 0 for non-melody).

        Returns
        -------
        None
            The fitted distributions are stored on the instance:
            self.rv_pos is the multivariate normal for the melody class,
            self.rv_neg is the multivariate normal for the non-melody class.
        """
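        # Box-Cox transform the features (_fit_boxcox is defined elsewhere in the class).
        X_boxcox = self._fit_boxcox(X)

        # Mean and covariance of the positive (melody) class; columns are features.
        pos_idx = np.where(Y == 1)[0]
        mu_pos = np.mean(X_boxcox[pos_idx, :], axis=0)
        cov_pos = np.cov(X_boxcox[pos_idx, :], rowvar=False)

        # Mean and covariance of the negative (non-melody) class.
        neg_idx = np.where(Y == 0)[0]
        mu_neg = np.mean(X_boxcox[neg_idx, :], axis=0)
        cov_neg = np.cov(X_boxcox[neg_idx, :], rowvar=False)

        # allow_singular=True tolerates a singular (rank-deficient) covariance matrix.
        rv_pos = multivariate_normal(
            mean=mu_pos, cov=cov_pos, allow_singular=True
        )
        rv_neg = multivariate_normal(
            mean=mu_neg, cov=cov_neg, allow_singular=True
        )
        self.rv_pos = rv_pos
        self.rv_neg = rv_neg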
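
The listing stops once the two distributions are stored, so how they are used afterwards is not shown. The following is a minimal standalone sketch of one plausible use: fitting the same per-class Gaussians on synthetic data and scoring samples by log-likelihood ratio. The helper names `fit_class_gaussians` and `log_likelihood_ratio`, the synthetic data, and the zero decision threshold are all assumptions made for illustration, and the Box-Cox step is skipped because `_fit_boxcox` is not part of the excerpt.

```python
import numpy as np
from scipy.stats import multivariate_normal


def fit_class_gaussians(X, Y):
    """Fit one multivariate normal per class (hypothetical helper, labels 1 and 0)."""
    fitted = {}
    for label in (1, 0):
        X_cls = X[Y == label]
        mu = np.mean(X_cls, axis=0)
        cov = np.cov(X_cls, rowvar=False)  # columns are features
        fitted[label] = multivariate_normal(
            mean=mu, cov=cov, allow_singular=True
        )
    return fitted[1], fitted[0]


def log_likelihood_ratio(rv_pos, rv_neg, X):
    """Positive values mean the sample is more likely under the positive class."""
    return rv_pos.logpdf(X) - rv_neg.logpdf(X)


# Synthetic, well-separated two-class data (illustrative only).
rng = np.random.default_rng(0)
X_pos = rng.normal(loc=2.0, scale=1.0, size=(200, 3))
X_neg = rng.normal(loc=-2.0, scale=1.0, size=(200, 3))
X = np.vstack([X_pos, X_neg])
Y = np.concatenate([np.ones(200), np.zeros(200)])

rv_pos, rv_neg = fit_class_gaussians(X, Y)
scores = log_likelihood_ratio(rv_pos, rv_neg, X)

# Assumed decision rule: threshold the log-likelihood ratio at zero.
predicted = (scores > 0).astype(int)
accuracy = np.mean(predicted == Y)
print("training accuracy:", accuracy)
```

Working in log space avoids underflow when the densities are very small, which is why the sketch subtracts `logpdf` values instead of dividing `pdf` values.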