test_async.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def test_shared_link(self):
    """Check interprocess parameter sharing works if models share links.

    Two models are built around copies of the same ``head`` link and their
    parameters are extracted as shared arrays. Those arrays are then handed
    to ``set_shared_params`` for a second, freshly built pair of models.
    The test passes when the head parameters (``/0/W``, ``/0/b``) of the
    two new models report the same data pointer while the tail parameters
    (``/1/W``, ``/1/b``) report different ones.
    """
    # ``async`` is a reserved keyword since Python 3.7, so a module can no
    # longer be referenced through that bare name (it is a SyntaxError).
    # NOTE(review): assumes the helpers are available as
    # ``chainerrl.misc.async_`` (the renamed module) -- confirm against the
    # installed chainerrl version and align the file-level import with it.
    from chainerrl.misc import async_

    head = L.Linear(2, 2)
    model_a = chainer.ChainList(head.copy(), L.Linear(2, 3))
    model_b = chainer.ChainList(head.copy(), L.Linear(2, 4))

    a_arrays = async_.extract_params_as_shared_arrays(
        chainer.ChainList(model_a))
    b_arrays = async_.extract_params_as_shared_arrays(
        chainer.ChainList(model_b))

    print(('model_a shared_arrays', a_arrays))
    print(('model_b shared_arrays', b_arrays))

    # Build a second pair of models with the same layout; their parameters
    # are the ones handed to set_shared_params and inspected below.
    head = L.Linear(2, 2)
    model_a = chainer.ChainList(head.copy(), L.Linear(2, 3))
    model_b = chainer.ChainList(head.copy(), L.Linear(2, 4))

    async_.set_shared_params(model_a, a_arrays)
    async_.set_shared_params(model_b, b_arrays)

    print('model_a replaced')
    a_params = dict(model_a.namedparams())
    for param_name, param in a_params.items():
        print((param_name, param.data.ctypes.data))

    print('model_b replaced')
    b_params = dict(model_b.namedparams())
    for param_name, param in b_params.items():
        print((param_name, param.data.ctypes.data))

    # Pointers to head parameters must be the same
    self.assertEqual(a_params['/0/W'].data.ctypes.data,
                     b_params['/0/W'].data.ctypes.data)
    self.assertEqual(a_params['/0/b'].data.ctypes.data,
                     b_params['/0/b'].data.ctypes.data)

    # Pointers to tail parameters must be different
    self.assertNotEqual(a_params['/1/W'].data.ctypes.data,
                        b_params['/1/W'].data.ctypes.data)
    self.assertNotEqual(a_params['/1/b'].data.ctypes.data,
                        b_params['/1/b'].data.ctypes.data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号