mv_gaussian.py source code

Project: SourceFilterContoursMelody    Author: juanjobosch

```python
import numpy as np
from scipy.stats import multivariate_normal


def fit_gaussians(x_train_boxcox, y_train):
    """Fit class-dependent multivariate Gaussians on the training set.

    Parameters
    ----------
    x_train_boxcox : np.array [n_samples, n_features_trans]
        Box-Cox-transformed training features.
    y_train : np.array [n_samples]
        Training labels (1 = melody, 0 = non-melody).

    Returns
    -------
    rv_pos : scipy.stats.multivariate_normal (frozen)
        Multivariate normal for the melody class.
    rv_neg : scipy.stats.multivariate_normal (frozen)
        Multivariate normal for the non-melody class.
    """
    # Melody class (label 1): per-feature mean and feature covariance.
    pos_idx = np.where(y_train == 1)[0]
    mu_pos = np.mean(x_train_boxcox[pos_idx, :], axis=0)
    cov_pos = np.cov(x_train_boxcox[pos_idx, :], rowvar=False)  # rows are samples

    # Non-melody class (label 0).
    neg_idx = np.where(y_train == 0)[0]
    mu_neg = np.mean(x_train_boxcox[neg_idx, :], axis=0)
    cov_neg = np.cov(x_train_boxcox[neg_idx, :], rowvar=False)

    # allow_singular=True tolerates rank-deficient covariance matrices
    # (e.g. redundant or constant features).
    rv_pos = multivariate_normal(mean=mu_pos, cov=cov_pos, allow_singular=True)
    rv_neg = multivariate_normal(mean=mu_neg, cov=cov_neg, allow_singular=True)
    return rv_pos, rv_neg
```
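The function expects features that have already been Box-Cox transformed and returns two frozen distributions. The sketch below shows one plausible end-to-end use under stated assumptions: each strictly positive feature column is Box-Cox transformed with `scipy.stats.boxcox`, the fitted lambdas are reused on unseen data, and the two class Gaussians are compared with a log-likelihood ratio (positive means "more melody-like"). The helpers `boxcox_columns`, `boxcox_apply`, `fit_class_gaussians` and `log_likelihood_ratio`, the synthetic lognormal data, and the zero decision threshold are all hypothetical and are not taken from the SourceFilterContoursMelody project.

```python
import numpy as np
from scipy.stats import boxcox, multivariate_normal


def boxcox_columns(x):
    """Hypothetical helper: Box-Cox each strictly positive column independently.

    Returns the transformed array and the fitted lambda per column.
    """
    x_trans = np.empty_like(x, dtype=float)
    lambdas = []
    for j in range(x.shape[1]):
        x_trans[:, j], lmbda = boxcox(x[:, j])  # lambda fitted by maximum likelihood
        lambdas.append(lmbda)
    return x_trans, np.array(lambdas)


def boxcox_apply(x, lambdas):
    """Hypothetical helper: reuse the training lambdas on new data."""
    return np.column_stack([boxcox(x[:, j], lmbda=l) for j, l in enumerate(lambdas)])


def fit_class_gaussians(x, y):
    """Same idea as fit_gaussians above, written with boolean masks."""
    rvs = []
    for label in (1, 0):
        xc = x[y == label]
        rvs.append(multivariate_normal(mean=xc.mean(axis=0),
                                       cov=np.cov(xc, rowvar=False),
                                       allow_singular=True))
    return rvs[0], rvs[1]


def log_likelihood_ratio(x, rv_pos, rv_neg):
    """Hypothetical score: log p(x | melody) - log p(x | non-melody)."""
    return rv_pos.logpdf(x) - rv_neg.logpdf(x)


# Synthetic, strictly positive features; label 1 = melody, 0 = non-melody.
rng = np.random.default_rng(0)
x_pos = rng.lognormal(mean=1.0, sigma=0.3, size=(200, 3))
x_neg = rng.lognormal(mean=0.0, sigma=0.3, size=(200, 3))
x_train = np.vstack([x_pos, x_neg])
y_train = np.concatenate([np.ones(200, dtype=int), np.zeros(200, dtype=int)])

x_trans, lambdas = boxcox_columns(x_train)
rv_pos, rv_neg = fit_class_gaussians(x_trans, y_train)

# Score the training set; with equal class priors, a ratio > 0 favours melody.
scores = log_likelihood_ratio(x_trans, rv_pos, rv_neg)
pred = (scores > 0).astype(int)
accuracy = np.mean(pred == y_train)

# Unseen data must go through the same Box-Cox lambdas before scoring.
x_test = np.vstack([rng.lognormal(1.0, 0.3, size=(25, 3)),
                    rng.lognormal(0.0, 0.3, size=(25, 3))])
test_scores = log_likelihood_ratio(boxcox_apply(x_test, lambdas), rv_pos, rv_neg)
```

Reusing the training lambdas matters: refitting Box-Cox on test data would place test features in a different space from the one the Gaussians were fitted in. With unbalanced classes, the zero threshold would need adjusting by the log prior ratio.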