def save(self, dir_name):
    """Persist the public entries of ``self`` under ``<root>/<dir_name>``.

    Iterates ``self.items()`` and skips keys starting with ``_``.

    - NumPy arrays and lists are written to ``<key>.npy``.
    - A ``chainer.Chain`` / ``chainer.ChainList`` is serialized to
      ``model.npz``.
    - A ``chainer.Optimizer`` is serialized to ``optimizer.npz``.
    - Every other entry is appended to ``log.txt`` as a ``"key: value"``
      line. A trailing newline is always written, even when there are no
      such entries.

    Args:
        dir_name: name of the sub-directory of ``self._root_dir_path`` to
            write into. It is created if missing; its parent must exist.
    """
    dir_path = os.path.join(self._root_dir_path, dir_name)
    # Create-and-ignore instead of check-then-create: an exists()/mkdir()
    # pair races with any other process creating the same directory.
    try:
        os.mkdir(dir_path)
    except FileExistsError:
        pass
    others = []
    for key, value in self.items():
        if key.startswith('_'):
            continue
        if isinstance(value, (np.ndarray, list)):
            np.save(os.path.join(dir_path, key + ".npy"), value)
        elif isinstance(value, (chainer.Chain, chainer.ChainList)):
            # NOTE(review): the file name does not include ``key``, so a
            # second Chain entry overwrites the first one's model.npz.
            model_path = os.path.join(dir_path, "model.npz")
            chainer.serializers.save_npz(model_path, value)
        elif isinstance(value, chainer.Optimizer):
            # NOTE(review): likewise, only the last Optimizer entry survives.
            optimizer_path = os.path.join(dir_path, "optimizer.npz")
            chainer.serializers.save_npz(optimizer_path, value)
        else:
            others.append("{}: {}".format(key, value))
    with open(os.path.join(dir_path, "log.txt"), "a") as f:
        f.write("\n".join(others) + "\n")
# Python class Optimizer() example source code
def set_shared_states(a, b):
    """Rebind optimizer ``a``'s per-parameter update-rule states onto shared buffers.

    Args:
        a: a ``chainer.Optimizer`` on which ``setup`` has already been called.
        b: mapping ``param_name -> {state_name: buffer}``. Each buffer is
            wrapped with ``np.frombuffer`` (no copy) using the dtype and shape
            of the state array it replaces.
    """
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    for name, param in a.target.namedparams():
        # The update rule's state dict must exist before it can be rebound.
        ensure_initialized_update_rule(param)
        rule_state = param.update_rule.state
        buffers = b[name]
        for key, buf in buffers.items():
            template = rule_state[key]
            shared_view = np.frombuffer(buf, dtype=template.dtype)
            rule_state[key] = shared_view.reshape(template.shape)
def extract_states_as_shared_arrays(optimizer):
    """Copy every update-rule state array of ``optimizer`` into a shared RawArray.

    Args:
        optimizer: a ``chainer.Optimizer`` on which ``setup`` has been called.

    Returns:
        dict ``param_name -> {state_name: RawArray}``. Each RawArray holds a
        flattened copy of the state array, typed after the array's dtype.
    """
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'
    shared_arrays = {}
    for param_name, param in optimizer.target.namedparams():
        shared_arrays[param_name] = {}
        ensure_initialized_update_rule(param)
        state = param.update_rule.state
        for state_name, state_val in state.items():
            # Use the array's own typecode ('f' for float32, 'd' for float64,
            # ...) rather than a hard-coded 'f'. set_shared_states reads the
            # buffer back with the state's original dtype, so the two must
            # agree or the reshape fails. Dtypes multiprocessing has no
            # typecode for (e.g. float16) now raise here instead.
            typecode = state_val.dtype.char
            shared_arrays[param_name][state_name] = mp.RawArray(
                typecode, state_val.ravel())
    return shared_arrays
def as_shared_objects(obj):
    """Return a process-shared counterpart of ``obj``.

    - Tuples are converted element-wise, recursively.
    - A ``chainer.Link`` is handed to ``share_params_as_shared_arrays``.
    - A ``chainer.Optimizer`` is handed to ``share_states_as_shared_arrays``.
    - A multiprocessing ``Synchronized`` value is returned unchanged.

    Raises:
        ValueError: if ``obj`` is none of the supported types.
    """
    # multiprocessing imports its ``sharedctypes`` submodule lazily, so
    # ``mp.sharedctypes`` may not exist yet; import the class explicitly.
    from multiprocessing.sharedctypes import Synchronized

    if isinstance(obj, tuple):
        return tuple(as_shared_objects(x) for x in obj)
    elif isinstance(obj, Synchronized):
        return obj
    elif isinstance(obj, chainer.Link):
        return share_params_as_shared_arrays(obj)
    elif isinstance(obj, chainer.Optimizer):
        return share_states_as_shared_arrays(obj)
    else:
        raise ValueError(
            'Cannot convert object of type {} to a shared object'.format(
                type(obj)))
def synchronize_to_shared_objects(obj, shared_memory):
    """Make ``obj`` use the shared buffers held in ``shared_memory``.

    ``shared_memory`` presumably mirrors what ``as_shared_objects`` produced
    for the same kind of object (TODO confirm with callers).

    - Tuples are handled element-wise, recursively.
    - A ``chainer.Link`` is rebound via ``set_shared_params`` and returned.
    - A ``chainer.Optimizer`` is rebound via ``set_shared_states`` and
      returned.
    - For a multiprocessing ``Synchronized`` value, ``shared_memory`` is
      returned instead of ``obj``.

    Raises:
        ValueError: if ``obj`` is none of the supported types.
    """
    # multiprocessing imports its ``sharedctypes`` submodule lazily, so
    # ``mp.sharedctypes`` may not exist yet; import the class explicitly.
    from multiprocessing.sharedctypes import Synchronized

    if isinstance(obj, tuple):
        # NOTE(review): zip() stops at the shorter input, so a length
        # mismatch silently drops trailing elements.
        return tuple(synchronize_to_shared_objects(o, s)
                     for o, s in zip(obj, shared_memory))
    elif isinstance(obj, Synchronized):
        return shared_memory
    elif isinstance(obj, chainer.Link):
        set_shared_params(obj, shared_memory)
        return obj
    elif isinstance(obj, chainer.Optimizer):
        set_shared_states(obj, shared_memory)
        return obj
    else:
        raise ValueError(
            'Cannot synchronize object of type {} to shared memory'.format(
                type(obj)))
def __init__(
        self,
        args,
        loss_maker,
        main_optimizer,
        main_lossfun,
        reinput_optimizer=None,
        reinput_lossfun=None,
        discriminator_optimizer=None,
        discriminator_lossfun=None,
        *_args, **kwargs
):
    # type: (any, comicolorization.loss.LossMaker, any, typing.Callable[[typing.Dict], any], typing.List[chainer.Optimizer], typing.Callable[[int, typing.Dict], any], any, typing.Callable[[typing.Dict], any], *any, **any) -> None
    """Register the optimizers with the parent updater and keep the loss callables.

    Optimizers are passed to the parent ``__init__`` under the keys:

    - ``'main'``
    - ``'reinput0'``, ``'reinput1'``, ... (only when ``reinput_optimizer`` is
      given)
    - ``'discriminator'`` (only when ``discriminator_optimizer`` is given)

    Extra positional and keyword arguments are forwarded to the parent.

    If ``args.separate_backward_reinput`` is set and no reinput optimizers were
    given, ``main_optimizer`` is reused once per entry of
    ``args.loss_blend_ratio_reinput``. This substitution happens after the
    parent ``__init__``, so those duplicates are not registered there.
    """
    optimizers = {'main': main_optimizer}
    if reinput_optimizer is not None:
        optimizers.update(
            ('reinput{}'.format(index), opt)
            for index, opt in enumerate(reinput_optimizer)
        )
    if discriminator_optimizer is not None:
        optimizers['discriminator'] = discriminator_optimizer
    super().__init__(*_args, optimizer=optimizers, **kwargs)

    # chainer.reporter cannot work when several optimizers focus on the same
    # model, so the main-optimizer fallback is only applied after registration.
    reuse_main = args.separate_backward_reinput and reinput_optimizer is None
    if reuse_main:
        reinput_optimizer = [main_optimizer] * len(args.loss_blend_ratio_reinput)

    self.args = args
    self.loss_maker = loss_maker
    self.main_optimizer = main_optimizer
    self.main_lossfun = main_lossfun
    self.reinput_optimizer = reinput_optimizer
    self.reinput_lossfun = reinput_lossfun
    self.discriminator_optimizer = discriminator_optimizer
    self.discriminator_lossfun = discriminator_lossfun
def set_shared_states(a, b):
    """Rebind the arrays in optimizer ``a``'s ``_states`` mapping onto shared buffers.

    Args:
        a: a ``chainer.Optimizer`` with ``setup`` already called, whose
            ``_states`` maps ``state_name -> {param_name: array}``.
        b: mapping with the same ``state_name -> {param_name: buffer}``
            layout. Each buffer is wrapped with ``np.frombuffer`` (no copy)
            using the dtype and shape of the array it replaces.
    """
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    for state_name in b:
        for param_name, buf in b[state_name].items():
            per_param = a._states[state_name]
            template = per_param[param_name]
            shared_view = np.frombuffer(buf, dtype=template.dtype)
            per_param[param_name] = shared_view.reshape(template.shape)
def extract_states_as_shared_arrays(optimizer):
    """Copy the arrays in ``optimizer._states`` into process-shared RawArrays.

    Args:
        optimizer: a ``chainer.Optimizer`` on which ``setup`` has been called.

    Returns:
        dict ``state_name -> {param_name: RawArray}``, mirroring the layout of
        ``optimizer._states``. Each RawArray holds a flattened copy of the
        corresponding array, typed after that array's dtype.
    """
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'
    shared_arrays = {}
    for state_name, state in optimizer._states.items():
        shared_arrays[state_name] = {}
        for param_name, param in state.items():
            # Use the array's own typecode instead of a hard-coded 'f'.
            # set_shared_states reinterprets the buffer with the original
            # dtype, which only lines up when the typecodes agree.
            shared_arrays[state_name][param_name] = mp.RawArray(
                param.dtype.char, param.ravel())
    return shared_arrays