def save_params(self, weights_file, catched=False):
    """Pickle the model's parameter values and metadata to a file.

    The payload is a dict with the keys ``layers_infos``, ``params_vl``,
    ``tag``, ``dropout``, ``ft_extractor``, ``dic_keys``, ``config_arch``
    and ``crop_size``. ``ft_extractor`` is stored only as a boolean flag
    telling whether ``self.ft_extractor`` is set, not the object itself.

    Args:
        weights_file: Path of the file to write. It is opened in binary
            mode and overwritten.
        catched: If True, save ``self.catched_params`` instead of the
            current values obtained from ``param.get_value()`` for each
            entry of ``self.params``.

    Raises:
        ValueError: If ``catched`` is True but ``self.catched_params`` is
            an empty list. Nothing is written in that case.
    """
    # Build the whole payload *before* touching the file, so that a failure
    # here cannot truncate an already existing weights file.
    if catched:
        if self.catched_params == []:
            raise ValueError(
                "You asked to save catched params, "
                "but you didn't catch any!!!!!!!")
        params_vl = self.catched_params
    else:
        params_vl = [param.get_value() for param in self.params]

    stuff = {
        "layers_infos": self.layers_infos,
        "params_vl": params_vl,
        "tag": self.tag,
        "dropout": self.dropout,
        # Only record whether a feature extractor is attached.
        "ft_extractor": self.ft_extractor is not None,
        "dic_keys": self.dic_keys,
        "config_arch": self.config_arch,
        "crop_size": self.crop_size,
    }

    # Pickle output is bytes: the file must be opened in binary mode
    # (text mode fails on Python 3 and can corrupt data on Windows).
    with open(weights_file, "wb") as f:
        pkl.dump(stuff, f, protocol=pkl.HIGHEST_PROTOCOL)
评论列表
文章目录