transformer_model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def __init__(self, data_dir, model_dir):
    """Create the Transformer estimator used for decoding.

    Configures the global tensor2tensor flags, builds the hyperparameters and
    an Estimator for ``FLAGS.model`` (with TPU disabled), caches the first
    problem's input/target vocabularies, and deletes previous debug dumps.

    Args:
      data_dir: The training data directory. A leading ``~`` is expanded.
      model_dir: The trained model directory. A leading ``~`` is expanded.
    """
    # Expand "~" once, up front, so the global FLAGS and the hparams all see
    # the same path. Previously FLAGS.data_dir / FLAGS.output_dir received the
    # raw, unexpanded value while create_hparams received the expanded one.
    data_dir = os.path.expanduser(data_dir)
    model_dir = os.path.expanduser(model_dir)

    # Do the pre-setup tensor2tensor requires for flags and configurations.
    # NOTE: this mutates the process-global FLAGS object.
    FLAGS.output_dir = model_dir
    FLAGS.data_dir = data_dir
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # Create the basic hyper parameters.
    self.hparams = tpu_trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=data_dir,
        problem_name=FLAGS.problems)

    # Decoding hyperparameters, registered with a single shard
    # (shards=1, shard_id=0).
    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyper parameters; TPU is explicitly off.
    self.estimator = tpu_trainer_lib.create_estimator(
        FLAGS.model,
        self.hparams,
        tpu_trainer.create_run_config(),
        decode_hp,
        use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    # Only the first configured problem's vocabularies are used.
    problem = self.hparams.problems[0]
    self.source_vocab = problem.vocabulary["inputs"]
    self.targets_vocab = problem.vocabulary["targets"]
    # Not read in this method. NOTE(review): presumably a size limit consumed
    # by decoding code elsewhere in the class — confirm before changing.
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory by deleting every
    # existing "run_*" entry left over from previous runs.
    dump_dir = "/tmp/t2t_server_dump"
    for run_dir in sorted(glob.glob(os.path.join(dump_dir, "run_*"))):
        shutil.rmtree(run_dir)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号