def __init__(self, model_path, gpu_id=None, random_seed=None):
    """Load the subword processor and NMT engine stored under ``model_path``.

    The directory must contain ``model.bpe``, which is loaded with
    ``SubwordTextProcessor.load_from_file``. It must also contain
    ``model.pt``, which is loaded with ``NMTEngine.load_from_checkpoint``.

    :param model_path: directory containing ``model.bpe`` and ``model.pt``.
    :param gpu_id: CUDA device index passed to ``torch.cuda.set_device``.
        When ``None``, the engine is loaded with ``using_cuda=False``.
    :param random_seed: optional seed. When given, it seeds torch, Python's
        ``random`` module and, if a GPU is selected, all CUDA generators.
    """
    import random  # stdlib RNG (guarantees ``random.seed`` is available)

    self._logger = logging.getLogger('nmmt.NMTDecoder')

    using_cuda = gpu_id is not None
    if using_cuda:
        torch.cuda.set_device(gpu_id)

    if random_seed is not None:
        torch.manual_seed(random_seed)
        # The stdlib ``random`` module has no ``manual_seed_all`` (the old call
        # raised AttributeError). Seed it with ``random.seed`` instead, and seed
        # the CUDA generators through torch.
        random.seed(random_seed)
        if using_cuda:
            torch.cuda.manual_seed_all(random_seed)

    self._text_processor = SubwordTextProcessor.load_from_file(os.path.join(model_path, 'model.bpe'))

    with log_timed_action(self._logger, 'Loading model from checkpoint'):
        self._engine = NMTEngine.load_from_checkpoint(os.path.join(model_path, 'model.pt'),
                                                      using_cuda=using_cuda)

    # Public-editable options: default values that callers may overwrite
    # after construction. Their exact semantics depend on code outside this
    # method; presumably beam-search width, max sentence length, UNK
    # replacement and tuning epochs (TODO confirm).
    self.beam_size = 5
    self.max_sent_length = 160
    self.replace_unk = False
    self.tuning_epochs = 5
# Scraped page residue, not code: "评论列表" (Comment list) / "文章目录" (Table of contents)