def extract_states_as_shared_arrays(optimizer):
    """Copy each parameter's optimizer state into ``multiprocessing`` RawArrays.

    For every ``(name, param)`` pair yielded by
    ``optimizer.target.namedparams()``, each entry of
    ``param.update_rule.state`` is flattened with ``.ravel()``. It is then
    copied into a new ``mp.RawArray`` with typecode ``'f'`` (C ``float``,
    i.e. single precision). RawArrays carry no lock, so any synchronization
    between processes is up to the caller.

    Args:
        optimizer: A ``chainer.Optimizer`` on which ``setup`` has already
            been called, so that it has a ``target`` attribute.

    Returns:
        dict: A nested mapping ``{param_name: {state_name: RawArray}}``. The
        inner dicts follow the iteration order of each ``update_rule.state``.

    Raises:
        AssertionError: If ``optimizer`` is not a ``chainer.Optimizer`` or
            has no ``target``. These are plain ``assert`` checks, so they
            are skipped when Python runs with ``-O``.
    """
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'

    arrays_by_param = {}
    for name, link_param in optimizer.target.namedparams():
        # NOTE(review): presumably this creates/initializes update_rule.state
        # when it is still empty (e.g. before the first update); confirm
        # against the helper's definition.
        ensure_initialized_update_rule(link_param)
        rule_state = link_param.update_rule.state
        # Each state value is assumed to be array-like with .ravel()
        # (NumPy-style); its elements are converted to C floats on copy.
        arrays_by_param[name] = {
            key: mp.RawArray('f', value.ravel())
            for key, value in rule_state.items()
        }
    return arrays_by_param
评论列表
文章目录