builder.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:lencon 作者: kiyukuta 项目源码 文件源码
def _build_model(self, config, src_vocab, trg_vocab):
        """Instantiate the model described by ``config['Model']``.

        The model class is looked up by name (``config['Model']['name']``) on
        the ``models`` module. Every other option in the ``Model`` section is
        passed to its constructor as a keyword argument, after converting
        numeric-looking strings to ``int`` or ``float``. If
        ``<self.save_dir>/model.hdf`` exists, its parameters are loaded into
        the new model with ``chainer.serializers.load_hdf5``.

        Args:
            config: Mapping from section name to a mapping of option strings
                (presumably a ``configparser.ConfigParser`` -- TODO confirm).
                Must contain a ``'Model'`` section with a ``'name'`` key and a
                ``'Training'`` section.
            src_vocab: Source vocabulary. Its ``stoi`` callable is used to
                look up the ids of ``'<s>'`` and ``'</s>'``.
            trg_vocab: Target vocabulary. It is used the same way, and is
                attached to the model as ``m.vocab`` when byte mode is on.

        Returns:
            The constructed model. BOS/EOS ids are set via ``set_symbols``,
            and the ``name``, ``byte`` and ``reverse_output`` attributes are
            assigned.
        """

        def convert(val):
            """Return ``val`` as an int, else as a float, else unchanged."""
            # Try int() first instead of gating on str.isdigit(). isdigit()
            # is False for signed integers such as '-1', which would
            # silently become floats. It is also True for characters such
            # as '2' in superscript form that int() rejects, which used to
            # raise here.
            try:
                return int(val)
            except ValueError:
                pass
            try:
                return float(val)
            except ValueError:
                # Not numeric: keep the raw string (e.g. a type or flag name).
                return val

        model_config = config['Model']

        # 'name' selects the class; every other option is a constructor kwarg.
        kwargs = {k: convert(v) for k, v in model_config.items() if k != 'name'}
        m = getattr(models, model_config['name'])(**kwargs)

        # Resume from previously saved parameters, if a snapshot exists.
        model_path = os.path.join(self.save_dir, 'model.hdf')
        if os.path.exists(model_path):
            chainer.serializers.load_hdf5(model_path, m)

        # Sentence-boundary token ids for the source (x) and target (y) sides.
        xstoi = src_vocab.stoi
        ystoi = trg_vocab.stoi
        xbos = xstoi('<s>')
        xeos = xstoi('</s>')
        ybos = ystoi('<s>')
        yeos = ystoi('</s>')
        m.set_symbols(xbos, xeos, ybos, yeos)

        m.name = model_config['name']
        training_config = config['Training']
        m.byte = self._load_binary_config(training_config, 'byte')
        m.reverse_output = self._load_binary_config(
            training_config, 'reverse_output')
        # Only byte-mode models get the target vocabulary attached.
        if m.byte:
            m.vocab = trg_vocab
        return m
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号