def set_up_predictor(self, nmt_model_path):
    """Initializes the predictor with the given NMT model. Code
    following ``blocks.machine_translation.main``.

    Builds the ``NMTModel`` described by ``self.config``, loads its
    weights through ``LoadNMTUtils``, and compiles a ``MyopticSearch``
    over the model's samples.

    Args:
        nmt_model_path: Location of the trained model, passed unchanged
            as the first argument to ``LoadNMTUtils`` (presumably a path
            to saved parameters -- confirm against ``LoadNMTUtils``).

    Side effects:
        Sets ``src_vocab_size``, ``trgt_vocab_size``, ``nmt_model``,
        ``best_models``, ``val_bleu_curve``, ``src_sparse_feat_map`` and
        ``search_algorithm`` on ``self``. A configured
        ``trg_sparse_feat_map`` is not supported; it is ignored and a
        warning is logged.
    """
    config = self.config
    self.src_vocab_size = config['src_vocab_size']
    self.trgt_vocab_size = config['trg_vocab_size']

    self.nmt_model = NMTModel(config)
    self.nmt_model.set_up()
    loader = LoadNMTUtils(nmt_model_path,
                          config['saveto'],
                          self.nmt_model.search_model)
    loader.load_weights()

    self.best_models = []
    self.val_bleu_curve = []

    # A falsy/unset source map falls back to a FlatSparseFeatMap.
    self.src_sparse_feat_map = (config['src_sparse_feat_map']
                                or FlatSparseFeatMap())
    if config['trg_sparse_feat_map']:
        # The setting is ignored and setup continues normally, so this is
        # a warning, not a fatal/critical condition.
        logging.warning("Cannot use bounded vocabulary predictor with "
                        "a target sparse feature map. Ignoring...")

    self.search_algorithm = MyopticSearch(samples=self.nmt_model.samples)
    self.search_algorithm.compile()
评论列表
文章目录