test_analysis.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ChainConsumer 作者: Samreay 项目源码 文件源码
def test_grid_data(self):
        x, y = np.linspace(-3, 3, 200), np.linspace(-5, 5, 200)
        xx, yy = np.meshgrid(x, y, indexing='ij')
        xs, ys = xx.flatten(), yy.flatten()
        chain = np.vstack((xs, ys)).T
        pdf = (1 / (2 * np.pi)) * np.exp(-0.5 * (xs * xs + ys * ys / 4))
        c = ChainConsumer()
        c.add_chain(chain, parameters=['x', 'y'], weights=pdf, grid=True)
        summary = c.analysis.get_summary()
        x_sum = summary['x']
        y_sum = summary['y']
        expected_x = np.array([-1.0, 0.0, 1.0])
        expected_y = np.array([-2.0, 0.0, 2.0])
        threshold = 0.1
        assert np.all(np.abs(expected_x - x_sum) < threshold)
        assert np.all(np.abs(expected_y - y_sum) < threshold)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号