NMTEngine.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:NeuralMT 作者: hlt-mt 项目源码 文件源码
def tune(self, src_batch, trg_batch, epochs, batch_size=32):
        """Fine-tune the loaded model on a small parallel batch of sentences.

        A trainer bound to this engine's model, optimizer and dictionaries is
        created lazily on the first call and reused on later calls. Each call
        then trains for exactly ``epochs`` epochs on the given sentence pairs.

        :param src_batch: iterable of source sentences, aligned one-to-one with
            ``trg_batch``. Each item is passed unchanged to
            ``self._src_dict.convertToIdx`` (presumably a list of tokens --
            TODO confirm with callers).
        :param trg_batch: iterable of target sentences aligned with ``src_batch``.
        :param epochs: number of epochs to train; used as both the minimum and
            the maximum epoch count of the tuner.
        :param batch_size: batch size passed to ``Dataset``. Defaults to 32, the
            value that was previously hard-coded.
        :raises ValueError: if ``src_batch`` and ``trg_batch`` differ in length.
        """
        # Materialize the inputs so arbitrary iterables (e.g. generators) keep
        # working and the two sides can be length-checked.
        src_batch = list(src_batch)
        trg_batch = list(trg_batch)

        # zip() would silently drop unmatched sentences. Misaligned parallel
        # data is a caller error, so report it instead of training on a subset.
        if len(src_batch) != len(trg_batch):
            raise ValueError('src_batch and trg_batch must have the same length '
                             '(got %d and %d)' % (len(src_batch), len(trg_batch)))

        # Nothing to tune on: skip model loading and training entirely.
        if not src_batch:
            return

        self._ensure_model_loaded()

        if self._tuner is None:
            # Use GPU 0 when CUDA is in use; otherwise pass gpu_ids=None.
            self._tuner = NMTEngineTrainer(self._model, self._optim, self._src_dict, self._trg_dict,
                                           model_params=self._model_params,
                                           gpu_ids=([0] if self._using_cuda else None))
            # NOTE(review): a negative threshold presumably disables
            # perplexity-based early stopping so every requested epoch runs --
            # confirm against NMTEngineTrainer.
            self._tuner.min_perplexity_decrement = -1.
            # Set the tuner's log level to NOTSET. The effect depends on how
            # NMTEngineTrainer configures its logger -- TODO confirm intent.
            self._tuner.set_log_level(logging.NOTSET)

        # Train for exactly `epochs` epochs on every call.
        self._tuner.min_epochs = self._tuner.max_epochs = epochs

        # Convert words to dictionary indexes; unknown words map to UNK_WORD.
        # Targets are additionally given BOS/EOS markers via convertToIdx.
        tuning_src_batch = [self._src_dict.convertToIdx(source, Constants.UNK_WORD)
                            for source in src_batch]
        tuning_trg_batch = [self._trg_dict.convertToIdx(target, Constants.UNK_WORD,
                                                        Constants.BOS_WORD, Constants.EOS_WORD)
                            for target in trg_batch]

        # Wrap the index sequences for training on the tuning batch.
        tuning_dataset = Dataset(tuning_src_batch, tuning_trg_batch, batch_size, self._using_cuda)

        # save_epochs=0 presumably disables checkpoint saving during tuning -- confirm.
        self._tuner.train_model(tuning_dataset, save_epochs=0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号