def test_share_states(self):
model = L.Linear(2, 2)
opt_a = optimizers.RMSprop()
opt_a.setup(model)
arrays = async.share_states_as_shared_arrays(opt_a)
opt_b = optimizers.RMSprop()
opt_b.setup(copy.deepcopy(model))
# In Chainer v2, a model cannot be set up by two optimizers or more.
opt_c = optimizers.RMSprop()
opt_c.setup(copy.deepcopy(model))
"""
Removed the tests by assert_different_pointers
since they are trivial now.
"""
async.set_shared_states(opt_b, arrays)
async.set_shared_states(opt_c, arrays)
def assert_same_pointers(a, b):
a = a.target
b = b.target
for param_name, param_a in a.namedparams():
param_b = dict(b.namedparams())[param_name]
state_a = param_a.update_rule.state
state_b = param_b.update_rule.state
self.assertTrue(state_a)
self.assertTrue(state_b)
for state_name, state_val_a in state_a.items():
state_val_b = state_b[state_name]
self.assertTrue(isinstance(
state_val_a, np.ndarray))
self.assertTrue(isinstance(
state_val_b, np.ndarray))
self.assertEqual(state_val_a.ctypes.data,
state_val_b.ctypes.data)
assert_same_pointers(opt_a, opt_b)
assert_same_pointers(opt_a, opt_c)
评论列表
文章目录