test_async.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def test_share_states(self):

        model = L.Linear(2, 2)
        opt_a = optimizers.RMSprop()
        opt_a.setup(model)
        arrays = async.share_states_as_shared_arrays(opt_a)
        opt_b = optimizers.RMSprop()
        opt_b.setup(copy.deepcopy(model))
        # In Chainer v2, a model cannot be set up by two optimizers or more.

        opt_c = optimizers.RMSprop()
        opt_c.setup(copy.deepcopy(model))

        """
        Removed the tests by assert_different_pointers
        since they are trivial now.
        """

        async.set_shared_states(opt_b, arrays)
        async.set_shared_states(opt_c, arrays)

        def assert_same_pointers(a, b):
            a = a.target
            b = b.target
            for param_name, param_a in a.namedparams():
                param_b = dict(b.namedparams())[param_name]
                state_a = param_a.update_rule.state
                state_b = param_b.update_rule.state
                self.assertTrue(state_a)
                self.assertTrue(state_b)
                for state_name, state_val_a in state_a.items():
                    state_val_b = state_b[state_name]
                    self.assertTrue(isinstance(
                        state_val_a, np.ndarray))
                    self.assertTrue(isinstance(
                        state_val_b, np.ndarray))
                    self.assertEqual(state_val_a.ctypes.data,
                                     state_val_b.ctypes.data)

        assert_same_pointers(opt_a, opt_b)
        assert_same_pointers(opt_a, opt_c)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号