interface.py source code


Project: sgnmt  Author: ucam-smt
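```python
import logging
import os

try:
  import tensorflow as tf  # optional dependency
  TENSORFLOW_AVAILABLE = True
except ImportError:
  TENSORFLOW_AVAILABLE = False

# Lazily created TensorFlow session shared by all predictors.
# TensorFlowNMTPredictor is defined elsewhere in SGNMT (not shown here).
session = None


def tf_get_nmt_predictor(args, nmt_path, nmt_config):
  """Get the TensorFlow NMT predictor.

  Args:
    args (object): SGNMT arguments from ``ArgumentParser``
    nmt_path (string): Path to the NMT model file or directory
    nmt_config (dict): NMT configuration, updated in place

  Returns:
    Predictor. An instance of ``TensorFlowNMTPredictor``, or ``None``
    if TensorFlow is not available
  """
  if not TENSORFLOW_AVAILABLE:
    logging.fatal("Could not find TensorFlow!")
    return None

  logging.info("Loading tensorflow nmt predictor")
  # A directory is used as the training directory, a file as the model.
  if os.path.isdir(nmt_path):
    nmt_config['train_dir'] = nmt_path
  elif os.path.isfile(nmt_path):
    nmt_config['model_path'] = nmt_path
  global session
  if session is None:
    session = tf.Session()  # TensorFlow 1.x API
  return TensorFlowNMTPredictor(args.cache_nmt_posteriors, nmt_config, session)
```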
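Apart from the TensorFlow check, `tf_get_nmt_predictor` does two things: it records `nmt_path` in `nmt_config` under `train_dir` or `model_path` depending on whether the path is a directory or a file, and it lazily creates one module-level session that every predictor reuses. The sketch below reproduces only that control flow, without TensorFlow, so it runs anywhere. `configure_nmt_path`, `get_shared_session` and the `session_factory` parameter are hypothetical names introduced for illustration, and `session_factory` merely stands in for `tf.Session`.

```python
import os
from typing import Any, Callable, Dict, Optional

# Module-level cache, mirroring the ``global session`` in tf_get_nmt_predictor.
_session: Optional[Any] = None


def configure_nmt_path(nmt_path: str, nmt_config: Dict[str, Any]) -> Dict[str, Any]:
  """Hypothetical helper: store nmt_path under the key the original uses."""
  if os.path.isdir(nmt_path):
    nmt_config['train_dir'] = nmt_path    # directory -> training directory
  elif os.path.isfile(nmt_path):
    nmt_config['model_path'] = nmt_path   # file -> model path
  # Any other path (e.g. a non-existent one) leaves the config unchanged,
  # just as in the original function.
  return nmt_config


def get_shared_session(session_factory: Callable[[], Any]) -> Any:
  """Hypothetical helper: create the session on first use, then reuse it."""
  global _session
  if _session is None:
    _session = session_factory()  # stands in for tf.Session()
  return _session


if __name__ == "__main__":
  import tempfile
  with tempfile.TemporaryDirectory() as tmp_dir:
    # Prints a dict whose only key is 'train_dir'.
    print(configure_nmt_path(tmp_dir, {}))
  first = get_shared_session(object)
  second = get_shared_session(object)
  print(first is second)  # True: the factory ran only once
```

Passing `tf.Session` as `session_factory` would correspond to the original TensorFlow 1.x behaviour; under TensorFlow 2 the equivalent class is `tf.compat.v1.Session`. Sharing one session keeps every predictor in the same TensorFlow runtime instead of opening a new session per predictor.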