test_kde.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:ChainConsumer 作者: Samreay 项目源码 文件源码
def test_megkde_2d_basic():
    # Draw from normal, fit KDE, see if sampling from kde's pdf recovers norm
    np.random.seed(1)
    data = np.random.multivariate_normal([0, 1], [[1.0, 0.], [0., 0.75 ** 2]], size=10000)
    xs, ys = np.linspace(-4, 4, 50), np.linspace(-4, 4, 50)
    xx, yy = np.meshgrid(xs, ys, indexing='ij')
    samps = np.vstack((xx.flatten(), yy.flatten())).T
    zs = MegKDE(data).evaluate(samps).reshape(xx.shape)
    zs_x = zs.sum(axis=1)
    zs_y = zs.sum(axis=0)
    cs_x = zs_x.cumsum()
    cs_x /= cs_x[-1]
    cs_x[0] = 0
    cs_y = zs_y.cumsum()
    cs_y /= cs_y[-1]
    cs_y[0] = 0
    samps_x = interp1d(cs_x, xs)(np.random.uniform(size=10000))
    samps_y = interp1d(cs_y, ys)(np.random.uniform(size=10000))
    mu_x, std_x = norm.fit(samps_x)
    mu_y, std_y = norm.fit(samps_y)
    assert np.isclose(mu_x, 0, atol=0.1)
    assert np.isclose(std_x, 1.0, atol=0.1)
    assert np.isclose(mu_y, 1, atol=0.1)
    assert np.isclose(std_y, 0.75, atol=0.1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号