def _build_model(self, config, src_vocab, trg_vocab):
    """Instantiate the model described by ``config`` and prepare it for use.

    The model class is looked up by name in the ``models`` module. It is
    constructed with every other entry of the ``'Model'`` section as a
    keyword argument; string values that parse as numbers become int/float.
    If ``<save_dir>/model.hdf`` exists, its weights are loaded into the new
    model. The model then receives the source/target ``<s>``/``</s>`` symbol
    ids, its ``name``, and the ``byte`` / ``reverse_output`` flags from the
    ``'Training'`` section.

    Args:
        config: Mapping with ``'Model'`` and ``'Training'`` sections.
            NOTE(review): presumably a ``configparser.ConfigParser``, so
            values arrive as strings. Confirm against the caller.
        src_vocab: Source vocabulary; ``src_vocab.stoi`` is called with a
            token string to obtain its id.
        trg_vocab: Target vocabulary with the same ``stoi`` interface. It is
            attached to the model as ``m.vocab`` when ``byte`` is enabled.

    Returns:
        The constructed (and possibly weight-restored) model instance.
    """
    def convert(val):
        """Coerce a config string to int, else float, else return it unchanged."""
        # Non-string values (e.g. from a dict-based config) are passed through
        # instead of crashing on str methods.
        if not isinstance(val, str):
            return val
        # Try int() first so signed integers such as "-1" stay ints. An
        # isdigit() gate would send them to float(), and would crash on
        # Unicode digits like "²" that isdigit() accepts but int() rejects.
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            # Not numeric: keep the raw string (e.g. a unit type name).
            return val

    model_config = config['Model']
    kwargs = {k: convert(v) for k, v in model_config.items() if k != 'name'}
    m = getattr(models, model_config['name'])(**kwargs)

    # Resume from previously saved weights when the file is present.
    model_path = os.path.join(self.save_dir, 'model.hdf')
    if os.path.exists(model_path):
        chainer.serializers.load_hdf5(model_path, m)

    xstoi = src_vocab.stoi
    ystoi = trg_vocab.stoi
    xbos = xstoi('<s>')
    xeos = xstoi('</s>')
    ybos = ystoi('<s>')
    yeos = ystoi('</s>')
    m.set_symbols(xbos, xeos, ybos, yeos)

    m.name = model_config['name']
    training_config = config['Training']
    m.byte = self._load_binary_config(training_config, 'byte')
    m.reverse_output = self._load_binary_config(
        training_config, 'reverse_output')
    if m.byte:
        m.vocab = trg_vocab
    return m
# Python serializers() example source code
def save(self):
    """Write the model weights and both vocabularies into ``self.save_dir``.

    The weights are taken from a CPU copy of ``self.model`` and stored as
    HDF5 in ``model.hdf``. The ``(src_vcb, trg_vcb)`` tuple is pickled into
    ``vocab.pkl``.
    """
    target_dir = self.save_dir

    # Serialize a copy so the live model is not moved off its device.
    # ``name`` is copied over explicitly; presumably copy() does not
    # preserve it (TODO confirm).
    cpu_model = self.model.copy()
    cpu_model.name = self.model.name
    cpu_model.to_cpu()
    chainer.serializers.save_hdf5(
        os.path.join(target_dir, 'model.hdf'), cpu_model)

    vocab_path = os.path.join(target_dir, "vocab.pkl")
    with open(vocab_path, "wb") as vocab_file:
        pickle.dump((self.src_vcb, self.trg_vcb), vocab_file)
def load_params(prefix, mdl, opt):
    """Restore ``mdl`` and ``opt`` in place from ``<prefix>.mdl`` / ``<prefix>.opt``.

    Both files are read with ``chainer.serializers.load_npz``; the model is
    loaded before the optimizer.
    """
    logging.getLogger(__name__).info('Loading model/optimizer parameters')
    for suffix, target in (('.mdl', mdl), ('.opt', opt)):
        chainer.serializers.load_npz(prefix + suffix, target)
def save_params(prefix, mdl, opt):
    """Write ``mdl`` and ``opt`` to ``<prefix>.mdl`` / ``<prefix>.opt``.

    Both files are written with ``chainer.serializers.save_npz``; the model
    is saved before the optimizer.
    """
    logging.getLogger(__name__).info('Saving model/optimizer parameters')
    for suffix, source in (('.mdl', mdl), ('.opt', opt)):
        chainer.serializers.save_npz(prefix + suffix, source)