def test_shared_link_copy(self):
    """Check how parameter-array sharing behaves under ``copy.deepcopy``.

    Two ChainLists are built around separate ``head.copy()`` links; their
    ``/0/W`` and ``/0/b`` parameters are expected to point at the same data
    buffers. Deep-copying each model on its own breaks that sharing, while
    deep-copying one container that holds both models keeps it (a single
    ``deepcopy`` pass copies each shared object only once via its memo).
    """
    import copy

    head = L.Linear(2, 2)
    model_a = chainer.ChainList(head.copy(), L.Linear(2, 3))
    model_b = chainer.ChainList(head.copy(), L.Linear(2, 4))

    # Parameter paths of the head link inside each ChainList.
    head_param_names = ('/0/W', '/0/b')

    def data_ptrs(model):
        # Map each head parameter path to the address of its data buffer
        # (NumPy ``ndarray.ctypes.data``), so aliasing is checked by memory
        # address rather than by value.
        params = dict(model.namedparams())
        return {name: params[name].data.ctypes.data
                for name in head_param_names}

    # The two copies of ``head`` share their parameter arrays.
    a_ptrs = data_ptrs(model_a)
    b_ptrs = data_ptrs(model_b)
    self.assertEqual(a_ptrs, b_ptrs)

    # When A and B are deep-copied separately, the head is no longer shared.
    # Both copies stay referenced until their addresses are compared;
    # otherwise one copy's freed buffers could be reused by the other and
    # yield equal addresses.
    model_a_copy = copy.deepcopy(model_a)
    model_b_copy = copy.deepcopy(model_b)
    a_copy_ptrs = data_ptrs(model_a_copy)
    b_copy_ptrs = data_ptrs(model_b_copy)
    for name in head_param_names:
        self.assertNotEqual(a_copy_ptrs[name], b_copy_ptrs[name])

    # When ChainList(A, B) is deep-copied as a whole, the head is still shared.
    model_total_copy = copy.deepcopy(chainer.ChainList(model_a, model_b))
    a_total_ptrs = data_ptrs(model_total_copy[0])
    b_total_ptrs = data_ptrs(model_total_copy[1])
    self.assertEqual(a_total_ptrs, b_total_ptrs)
    # ...and it really is a fresh copy, not the original arrays (otherwise
    # the sharing check above would pass vacuously).
    for name in head_param_names:
        self.assertNotEqual(a_total_ptrs[name], a_ptrs[name])
评论列表
文章目录