blocks_nmt.py source code

python
Project: sgnmt  Author: ucam-smt
def set_up_predictor(self, nmt_model_path):
        """Initializes the predictor with the given NMT model. Code 
        following ``blocks.machine_translation.main``. 
        """
        self.src_vocab_size = self.config['src_vocab_size']
        self.trgt_vocab_size = self.config['trg_vocab_size']
        self.nmt_model = NMTModel(self.config)
        self.nmt_model.set_up()
        loader = LoadNMTUtils(nmt_model_path,
                              self.config['saveto'],
                              self.nmt_model.search_model)
        loader.load_weights()

        self.best_models = []
        self.val_bleu_curve = []
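        # Use the configured source sparse feature map, or fall back to a
        # flat map when none is configured.
        self.src_sparse_feat_map = self.config['src_sparse_feat_map'] \
                if self.config['src_sparse_feat_map'] else FlatSparseFeatMap()
        if self.config['trg_sparse_feat_map']:
            # logging.fatal only logs at CRITICAL level; set-up continues.
            logging.fatal("Cannot use bounded vocabulary predictor with "
                          "a target sparse feature map. Ignoring...")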
        self.search_algorithm = MyopticSearch(samples=self.nmt_model.samples)
        self.search_algorithm.compile()
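
The feature-map handling in the method above can be tried in isolation. The following is a minimal sketch, not the sgnmt implementation: the `FlatSparseFeatMap` class here is an empty stand-in, and `resolve_feat_maps` and the sample config values are hypothetical names introduced only for illustration. It shows the fallback for the source map and that `logging.fatal` (an alias for `logging.critical`) logs a message without stopping execution.

```python
import logging


class FlatSparseFeatMap(object):
    """Empty stand-in (assumption) for the sgnmt class of the same name."""


def resolve_feat_maps(config):
    """Hypothetical helper mirroring the feature-map logic shown above."""
    # Fall back to a flat map when no source map is configured.
    src_map = config['src_sparse_feat_map'] \
        if config['src_sparse_feat_map'] else FlatSparseFeatMap()
    trg_ignored = bool(config['trg_sparse_feat_map'])
    if trg_ignored:
        # logging.fatal logs at CRITICAL level; it does not raise or
        # exit, so the caller keeps running after this message.
        logging.fatal("Target sparse feature map is ignored.")
    return src_map, trg_ignored


# Sample config (made-up values): no source map, a target map is set.
src_map, trg_ignored = resolve_feat_maps({
    'src_sparse_feat_map': None,
    'trg_sparse_feat_map': 'some_map',
})
```

With this sample config, `src_map` is a `FlatSparseFeatMap` instance and the target map is reported as ignored while the program continues.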