def tf_get_nmt_predictor(args, nmt_path, nmt_config):
    """Get the TensorFlow NMT predictor.

    If ``nmt_path`` is a directory it is stored in ``nmt_config`` under
    ``'train_dir'``; if it is a file it is stored under ``'model_path'``.
    Otherwise ``nmt_config`` is left unchanged and a warning is logged.
    Note that ``nmt_config`` is modified in place.

    A single TensorFlow session is kept in the module-level ``session``
    global: it is created on first use and reused by later calls.

    Args:
        args (object): SGNMT arguments from ``ArgumentParser``. Only
            ``args.cache_nmt_posteriors`` is read here.
        nmt_path (string): Path to NMT model file or directory
        nmt_config (dict): NMT configuration; updated in place with
            ``'train_dir'`` or ``'model_path'`` as described above.

    Returns:
        Predictor. An instance of ``TensorFlowNMTPredictor``, or ``None``
        if TensorFlow is not available.
    """
    global session
    if not TENSORFLOW_AVAILABLE:
        logging.fatal("Could not find TensorFlow!")
        return None
    logging.info("Loading tensorflow nmt predictor")
    # Guard against an empty/None path: os.path.isdir(None) raises TypeError.
    if nmt_path and os.path.isdir(nmt_path):
        nmt_config['train_dir'] = nmt_path
    elif nmt_path and os.path.isfile(nmt_path):
        nmt_config['model_path'] = nmt_path
    else:
        logging.warning("NMT path '%s' is neither a directory nor a file; "
                        "leaving the NMT config unchanged", nmt_path)
    # Lazily create one shared session so multiple predictors reuse it.
    if not session:
        session = tf.Session()
    return TensorFlowNMTPredictor(args.cache_nmt_posteriors,
                                  nmt_config,
                                  session)
评论列表
文章目录