def _build_model(self, config, src_vocab, trg_vocab):
    """Instantiate the model named in ``config['Model']`` and configure it.

    Every entry of the ``Model`` section except ``name`` is passed to the
    model constructor as a keyword argument, after numeric-looking strings
    have been converted to ``int`` or ``float``. If a ``model.hdf`` file
    exists in ``self.save_dir``, its contents are loaded into the new model.

    Args:
        config: Mapping with a ``'Model'`` and a ``'Training'`` section.
            The ``'Model'`` section must contain ``name``, the attribute of
            ``models`` to instantiate.
        src_vocab: Source vocabulary; its ``stoi`` is called with ``'<s>'``
            and ``'</s>'``.
        trg_vocab: Target vocabulary; its ``stoi`` is called with the same
            two symbols. It is also attached to the model as ``vocab`` when
            the ``byte`` training option is set.

    Returns:
        The constructed model, with its boundary symbols, ``name``,
        ``byte`` and ``reverse_output`` attributes set.
    """
    def convert(val):
        """Turn a config value into ``int`` or ``float`` when it parses as one.

        Values that parse as neither are returned unchanged.
        """
        # Already-numeric values pass through untouched; calling int() on a
        # float would silently truncate it.
        if isinstance(val, (int, float)):
            return val
        # Try int before float so integral values (including signed ones
        # such as "-1") stay integers. Only conversion failures are caught;
        # anything else propagates.
        for cast in (int, float):
            try:
                return cast(val)
            except (TypeError, ValueError):
                continue
        return val

    model_config = config['Model']
    # 'name' selects the model class; it is not a constructor argument.
    kwargs = {k: convert(v) for k, v in model_config.items() if k != 'name'}
    m = getattr(models, model_config['name'])(**kwargs)

    # Load previously saved parameters if a model file is present.
    model_path = os.path.join(self.save_dir, 'model.hdf')
    if os.path.exists(model_path):
        chainer.serializers.load_hdf5(model_path, m)

    # Look up the sentence-boundary symbols in each vocabulary and hand the
    # results to the model.
    xstoi = src_vocab.stoi
    ystoi = trg_vocab.stoi
    xbos = xstoi('<s>')
    xeos = xstoi('</s>')
    ybos = ystoi('<s>')
    yeos = ystoi('</s>')
    m.set_symbols(xbos, xeos, ybos, yeos)

    m.name = model_config['name']
    m.byte = self._load_binary_config(config['Training'], 'byte')
    m.reverse_output = self._load_binary_config(
        config['Training'], 'reverse_output')
    # NOTE(review): the target vocabulary is only attached in byte mode;
    # presumably the model needs it there for decoding — confirm.
    if m.byte:
        m.vocab = trg_vocab
    return m
评论列表
文章目录