def save(self, dir_name):
    """Persist the public entries of this container into a sub-directory.

    The target directory is ``os.path.join(self._root_dir_path, dir_name)``.
    It is created if needed, together with any missing parent directories.

    Every ``(key, value)`` pair yielded by ``self.items()`` whose key does not
    start with ``'_'`` is written according to the value's type:

    * ``np.ndarray`` or ``list``: saved to ``<key>.npy`` with ``np.save``.
    * ``chainer.Chain`` / ``chainer.ChainList``: saved to ``model.npz`` with
      ``chainer.serializers.save_npz``.
    * ``chainer.Optimizer``: saved to ``optimizer.npz`` with
      ``chainer.serializers.save_npz``.
    * anything else: formatted as ``"key: value"`` and appended, one per
      line, to ``log.txt``.

    Args:
        dir_name: Name of the sub-directory (relative to
            ``self._root_dir_path``) to write into.

    Note:
        The model and optimizer file names are fixed. If several
        ``Chain``/``ChainList`` values (or several ``Optimizer`` values) are
        present, each overwrites the previous file, and only the last one
        iterated is kept.

        ``log.txt`` is always appended to, even when there are no "other"
        values; in that case a lone newline is written.
    """
    dir_path = os.path.join(self._root_dir_path, dir_name)
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates a missing root directory, which os.mkdir would reject.
    os.makedirs(dir_path, exist_ok=True)

    others = []
    for key, value in self.items():
        if key.startswith('_'):
            # Keys beginning with '_' are skipped and never persisted.
            continue
        if isinstance(value, (np.ndarray, list)):
            np.save(os.path.join(dir_path, key + ".npy"), value)
        elif isinstance(value, (chainer.Chain, chainer.ChainList)):
            model_path = os.path.join(dir_path, "model.npz")
            chainer.serializers.save_npz(model_path, value)
        elif isinstance(value, chainer.Optimizer):
            optimizer_path = os.path.join(dir_path, "optimizer.npz")
            chainer.serializers.save_npz(optimizer_path, value)
        else:
            others.append("{}: {}".format(key, value))

    # Append mode: repeated saves into the same directory accumulate entries.
    # An explicit encoding prevents UnicodeEncodeError for non-ASCII values
    # on platforms whose default locale encoding is not UTF-8.
    log_path = os.path.join(dir_path, "log.txt")
    with open(log_path, "a", encoding="utf-8") as f:
        f.write("\n".join(others) + "\n")
评论列表
文章目录