def load_params(self, saveto):
    """Load model parameters from the numpy archive at `saveto`."""
    try:
        logger.info(" ...loading model parameters")
        params_all = numpy.load(saveto)
        params_this = self.get_params()
        num_loaded = 0
        for pname in params_this.keys():
            if pname in params_all:
                # Parameter stored under its own name: load it directly.
                val = params_all[pname]
                self._set_param_value(params_this[pname], val, pname)
                num_loaded += 1
            elif self.num_decs > 1 and self.decoder.share_att and \
                    pname in self.decoder.shared_params_map:
                # Shared-attention parameter: load the value saved under
                # the name of the parameter it is tied to.
                val = params_all[self.decoder.shared_params_map[pname]]
                self._set_param_value(params_this[pname], val, pname)
                num_loaded += 1
            else:
                logger.warning(
                    " Parameter does not exist: {}".format(pname))
        logger.info(
            " Number of params loaded: {}".format(num_loaded))
    except Exception as e:
        logger.error(" Error {0}".format(str(e)))
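
For reference, the archive read above is a numpy .npz-style file keyed by parameter name. Below is a minimal sketch of the save/load round trip, assuming a `model` object that exposes the `get_params()` and `load_params()` methods shown above; the `save_params` helper and the Theano-style `get_value()` call are assumptions for illustration, not part of the original code.

import numpy

def save_params(model, saveto):
    # Hypothetical counterpart to load_params: dump each parameter's
    # current value into one .npz archive keyed by parameter name.
    values = {pname: param.get_value()  # assumes Theano-style shared variables
              for pname, param in model.get_params().items()}
    numpy.savez(saveto, **values)

# Assumed usage: persist, then restore into a freshly built model.
# save_params(model, "model.npz")
# model.load_params("model.npz")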