test_async.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def test_shared_link_copy(self):
        """Check how deepcopy treats parameters shared through ``Link.copy``.

        Two ChainLists are built around ``head.copy()``, and their first
        children are asserted to share the same W and b buffers. Deep-copying
        each ChainList on its own is then asserted to break that sharing.
        Deep-copying a single ChainList containing both is asserted to keep
        it. A single ``copy.deepcopy`` call memoizes objects it has already
        copied.
        """
        import copy

        shared_keys = ('/0/W', '/0/b')

        def data_ptrs(link):
            # Buffer address of each shared parameter, in shared_keys order.
            params = dict(link.namedparams())
            return [params[key].data.ctypes.data for key in shared_keys]

        head = L.Linear(2, 2)
        model_a = chainer.ChainList(head.copy(), L.Linear(2, 3))
        model_b = chainer.ChainList(head.copy(), L.Linear(2, 4))
        for ptr_a, ptr_b in zip(data_ptrs(model_a), data_ptrs(model_b)):
            self.assertEqual(ptr_a, ptr_b)

        # Deep-copying A and B separately: the head is no longer shared.
        separate_a = copy.deepcopy(model_a)
        separate_b = copy.deepcopy(model_b)
        for ptr_a, ptr_b in zip(data_ptrs(separate_a), data_ptrs(separate_b)):
            self.assertNotEqual(ptr_a, ptr_b)

        # Deep-copying one ChainList(A, B): the head is still shared.
        joint = copy.deepcopy(chainer.ChainList(model_a, model_b))
        for ptr_a, ptr_b in zip(data_ptrs(joint[0]), data_ptrs(joint[1])):
            self.assertEqual(ptr_a, ptr_b)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号