def test_grid_data(self):
    """Check the summary of a chain supplied as a weighted regular grid.

    A 200 x 200 grid over x in [-3, 3] and y in [-5, 5] is weighted by a
    Gaussian-shaped function with scale 1 in x and scale 2 in y, then added
    to a ChainConsumer with ``grid=True``. The three summary values for
    each parameter must lie within 0.1 of [-1, 0, 1] for x and [-2, 0, 2]
    for y.
    """
    tolerance = 0.1
    target_x = np.array([-1.0, 0.0, 1.0])
    target_y = np.array([-2.0, 0.0, 2.0])

    # Build the grid; 'ij' indexing keeps x on the first axis.
    x_axis = np.linspace(-3, 3, 200)
    y_axis = np.linspace(-5, 5, 200)
    grid_x, grid_y = np.meshgrid(x_axis, y_axis, indexing='ij')
    flat_x = grid_x.flatten()
    flat_y = grid_y.flatten()

    # One row per grid point, columns ordered (x, y).
    samples = np.vstack((flat_x, flat_y)).T

    # Weight per grid point: the y term is divided by 4, i.e. scale 2 in y.
    exponent = -0.5 * (flat_x * flat_x + flat_y * flat_y / 4)
    weights = (1 / (2 * np.pi)) * np.exp(exponent)

    consumer = ChainConsumer()
    consumer.add_chain(samples, parameters=['x', 'y'], weights=weights, grid=True)
    results = consumer.analysis.get_summary()
    found_x = results['x']
    found_y = results['y']

    assert np.all(np.abs(target_x - found_x) < tolerance)
    assert np.all(np.abs(target_y - found_y) < tolerance)
# Comment list
# Table of contents