NMTDecoder.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:NeuralMT 作者: hlt-mt 项目源码 文件源码
def __init__(self, model_path, gpu_id=None, random_seed=None):
        """Load the subword text processor and the NMT engine from ``model_path``.

        Args:
            model_path: Directory that contains the ``model.bpe`` and
                ``model.pt`` files read below.
            gpu_id: CUDA device index passed to ``torch.cuda.set_device``.
                When ``None`` no device is selected and the engine is loaded
                with ``using_cuda=False``.
            random_seed: Optional seed passed to ``torch.manual_seed`` and
                ``torch.cuda.manual_seed_all``. When ``None`` no seeding is
                performed.
        """
        self._logger = logging.getLogger('nmmt.NMTDecoder')

        # CUDA is used if and only if the caller picked a device.
        using_cuda = gpu_id is not None
        if using_cuda:
            torch.cuda.set_device(gpu_id)

        if random_seed is not None:
            torch.manual_seed(random_seed)
            # Call the torch API directly instead of going through a ``random``
            # alias: the stdlib ``random`` module has no ``manual_seed_all``, so
            # the previous ``random.manual_seed_all`` only worked if ``random``
            # happened to be bound to ``torch.cuda.random``.
            torch.cuda.manual_seed_all(random_seed)

        bpe_path = os.path.join(model_path, 'model.bpe')
        checkpoint_path = os.path.join(model_path, 'model.pt')

        self._text_processor = SubwordTextProcessor.load_from_file(bpe_path)
        with log_timed_action(self._logger, 'Loading model from checkpoint'):
            self._engine = NMTEngine.load_from_checkpoint(checkpoint_path, using_cuda=using_cuda)

        # Public-editable options
        self.beam_size = 5
        self.max_sent_length = 160
        self.replace_unk = False
        self.tuning_epochs = 5
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号