test_dream.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:PyDREAM 作者: LoLab-VU 项目源码 文件源码
def test_history_recording_multidim_model(self):
        """Test that history in memory matches with that recorded for test multi-dimensional model."""
        self.param, self.like = multidmodel()
        model = Model(self.like, self.param)
        dream = Dream(model=model, model_name='test_history_recording')
        history_arr = mp.Array('d', [0]*4*dream.total_var_dimension*3)
        n = mp.Value('i', 0)
        nchains = mp.Value('i', 3)
        pydream.Dream_shared_vars.history = history_arr
        pydream.Dream_shared_vars.count = n
        pydream.Dream_shared_vars.nchains = nchains
        test_history = np.array([[[1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8]], [[7, 8, 9, 10], [9, 12, 18, 20], [11, 14, 18, 8]], [[13, 14, 18, 4], [15, 17, 11, 8], [17, 28, 50, 4]], [[19, 21, 1, 18], [21, 19, 19, 11], [23, 4, 3, 2]]])
        for chainpoint in test_history:
            for point in chainpoint:
                dream.record_history(nseedchains=0, ndimensions=dream.total_var_dimension, q_new=point, len_history=len(history_arr))
        history_arr_np = np.frombuffer(pydream.Dream_shared_vars.history.get_obj())
        history_arr_np_reshaped = history_arr_np.reshape(np.shape(test_history))
        self.assertIs(np.array_equal(history_arr_np_reshaped, test_history), True)
        remove('test_history_recording_DREAM_chain_history.npy')
        remove('test_history_recording_DREAM_chain_adapted_crossoverprob.npy')
        remove('test_history_recording_DREAM_chain_adapted_gammalevelprob.npy')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号