test_inference_plda.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:plda 作者: RaviSoji 项目源码 文件源码
def gen_A(self, V, n_classes, n_dims, return_S_b=False):
        """ A = [B][inv(? ** .5)][Q^T] and assumes same number of data
             in each class v. """
        B = np.random.randint(-100, 100, (n_dims, n_dims)).astype(float)
        big_V = np.matmul(V.T, V)  # V is now a scatter matrix.
        vals, vecs = eig(big_V)
        A = B / np.sqrt(vals.real)
        A = np.matmul(A, vecs.T)

        D = np.matmul(np.matmul(vecs.T, big_V), vecs)
        assert np.allclose(D, np.diag(vals))

        if return_S_b is True:
            S_b = 1 /n_classes * np.matmul(np.matmul(A, big_V), A.T)
            x = np.matmul(A, V.T).T

            S_b_empirical = 1 / n_classes * np.matmul(x.T, x)
            assert np.allclose(S_b, S_b_empirical)

            return A, S_b
        else:
            return A
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号