test_baseline.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:nnmnkwii 作者: r9y9 项目源码 文件源码
def test_gmmmap_swap():
    static_dim = 2
    T = 10
    windows = _get_windows_set()[-1]

    np.random.seed(1234)
    src_mc = np.random.rand(T, static_dim * len(windows))
    tgt_mc = np.random.rand(T, static_dim * len(windows))

    # pseudo parallel data
    XY = np.concatenate((src_mc, tgt_mc), axis=-1)
    gmm = GaussianMixture(n_components=4)
    gmm.fit(XY)

    paramgen = MLPG(gmm, windows=windows, swap=False)
    swap_paramgen = MLPG(gmm, windows=windows, swap=True)

    mc_converted1 = paramgen.transform(src_mc)
    mc_converted2 = swap_paramgen.transform(tgt_mc)

    src_mc = src_mc[:, :static_dim]
    tgt_mc = tgt_mc[:, :static_dim]

    assert norm(tgt_mc - mc_converted1) < norm(src_mc - mc_converted1)
    assert norm(tgt_mc - mc_converted2) > norm(src_mc - mc_converted2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号