def test_megkde_2d_basic():
    """Check that a 2D MegKDE fit reproduces the marginals of its input data.

    Draws points from a 2D normal with mean ``[0, 1]`` and a diagonal
    covariance (standard deviations 1.0 and 0.75), evaluates ``MegKDE`` on a
    regular 50x50 grid over ``[-4, 4] x [-4, 4]``, and sums the grid values
    along each axis to obtain one marginal per dimension. Each marginal is
    turned into a cumulative curve, sampled by inverse interpolation, and the
    normal fit to those samples must match the input mean and standard
    deviation to within 0.1.
    """
    n_points = 10000

    def draw_from_marginal(density, grid):
        # Cumulative sum of the gridded density, rescaled so it ends at 1.
        cdf = density.cumsum()
        cdf /= cdf[-1]
        # Pin the first value to 0 so the curve spans [0, 1] and every
        # uniform draw below lies inside the interpolation range.
        cdf[0] = 0
        # Inverse-CDF sampling: map uniform draws back onto grid positions.
        return interp1d(cdf, grid)(np.random.uniform(size=n_points))

    # Fixed seed so the random data and the uniform draws are reproducible.
    np.random.seed(1)
    data = np.random.multivariate_normal(
        [0, 1], [[1.0, 0.], [0., 0.75 ** 2]], size=n_points
    )

    xs, ys = np.linspace(-4, 4, 50), np.linspace(-4, 4, 50)
    xx, yy = np.meshgrid(xs, ys, indexing='ij')
    grid_points = np.vstack((xx.flatten(), yy.flatten())).T
    zs = MegKDE(data).evaluate(grid_points).reshape(xx.shape)

    # With indexing='ij' the first axis follows xs and the second follows ys,
    # so summing over axis 1 leaves the x marginal and axis 0 the y marginal.
    samps_x = draw_from_marginal(zs.sum(axis=1), xs)
    samps_y = draw_from_marginal(zs.sum(axis=0), ys)

    mu_x, std_x = norm.fit(samps_x)
    mu_y, std_y = norm.fit(samps_y)

    assert np.isclose(mu_x, 0, atol=0.1)
    assert np.isclose(std_x, 1.0, atol=0.1)
    assert np.isclose(mu_y, 1, atol=0.1)
    assert np.isclose(std_y, 0.75, atol=0.1)
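# Scraped web-page residue, not part of the test: "Comment list"
# Scraped web-page residue, not part of the test: "Article table of contents"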