test_batch_base.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目：megamix　作者：14thibea
def test_log_normal_matrix_full():
    """Check ``_log_normal_matrix`` against a direct log-density computation
    for full covariance matrices.

    For each point i and component k the expected entry (i, k) is

        -0.5 * (d * log(2*pi)
                + (x_i - mu_k)^T Sigma_k^{-1} (x_i - mu_k)
                + log det(Sigma_k))

    where d is ``n_features``.
    """
    n_points, n_components, n_features = 10, 5, 2

    points = np.random.randn(n_points, n_features)
    means = np.random.randn(n_components, n_features)
    # Presumably an (n_components, n_features, n_features) stack of
    # covariance matrices -- TODO confirm against the helper's definition.
    cov = generate_covariance_matrices_full(n_components, n_features)

    # Reference computation.
    # np.linalg.det / np.linalg.inv act on the trailing two axes, giving one
    # determinant and one precision matrix per stacked covariance.
    log_det_cov = np.log(np.linalg.det(cov))
    precisions = np.linalg.inv(cov)
    mahalanobis_sq = np.empty((n_points, n_components))
    for k in range(n_components):
        diff = points - means[k]
        # Row-wise quadratic form diff_i^T P_k diff_i: identical to the
        # diagonal of diff @ P_k @ diff.T, but without materialising the
        # full (n_points, n_points) matrix.
        mahalanobis_sq[:, k] = np.sum(np.dot(diff, precisions[k]) * diff,
                                      axis=1)

    # Assuming cov is a stack as noted above, log_det_cov has one entry per
    # component and broadcasts across the rows (points).
    expected_log_normal_matrix = -0.5 * (n_features * np.log(2 * np.pi) +
                                         mahalanobis_sq + log_det_cov)

    predicted_log_normal_matrix = _log_normal_matrix(points, means, cov,
                                                     'full')

    assert_almost_equal(expected_log_normal_matrix,
                        predicted_log_normal_matrix)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号