test_baseline.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:nnmnkwii 作者: r9y9 项目源码 文件源码
def test_diffvc():
    """Exercise MLPG-based conversion with ``diff=False`` and ``diff=True``.

    For every window set yielded by ``_get_windows_set()``:

    * random "source" and "target" feature matrices of shape
      ``(T, static_dim * len(windows))`` are drawn after re-seeding NumPy,
    * a 4-component ``GaussianMixture`` is fit on their frame-wise
      concatenation (pseudo-parallel data),
    * two ``MLPG`` generators (``diff=False`` and ``diff=True``) built from
      that GMM transform the source features.

    Both outputs must have shape ``(T, static_dim)``. The ``diff=False``
    output must also be closer, in ``norm``, to the target's first
    ``static_dim`` columns than to the source's.
    """
    # MLPG works one static dimension at a time, so a single static
    # dimension would be enough; two are used as an extra safeguard.
    static_dim = 2
    T = 10

    for windows in _get_windows_set():
        # Reset the global NumPy seed at the start of every window set.
        np.random.seed(1234)
        feat_dim = static_dim * len(windows)
        source = np.random.rand(T, feat_dim)
        target = np.random.rand(T, feat_dim)

        # Pseudo-parallel data: each frame is [source | target].
        joint = np.concatenate((source, target), axis=-1)
        model = GaussianMixture(n_components=4)
        model.fit(joint)

        plain_gen = MLPG(model, windows=windows, diff=False)
        diff_gen = MLPG(model, windows=windows, diff=True)

        converted = plain_gen.transform(source)
        # NOTE(review): only the output shape of the diff=True path is
        # checked below, and its GMM is fit on plain [source | target]
        # frames; confirm whether a quality assertion is intended.
        converted_diff = diff_gen.transform(source)

        expected_shape = (T, static_dim)
        assert converted.shape == expected_shape
        assert converted_diff.shape == expected_shape

        # Compare against the first static_dim columns only, matching the
        # (T, static_dim) output shape (presumably the static features —
        # confirm against the window layout).
        source_static = source[:, :static_dim]
        target_static = target[:, :static_dim]
        assert norm(target_static - converted) < norm(source_static - converted)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号