def __init__(self, data_dir, model_dir):
  """Creates the Transformer estimator.

  Sets up tensor2tensor's global FLAGS, builds the hyper parameters and the
  (CPU, single-shard) decoding estimator, caches the input/target
  vocabularies, and clears debug dumps left over from previous runs.

  Args:
    data_dir: The training data directory. "~" is expanded.
    model_dir: The trained model directory.
  """
  # Expand "~" up front so FLAGS.data_dir and the hparams receive the same
  # fully-resolved path (previously FLAGS got the unexpanded path while only
  # the local copy was expanded).
  data_dir = os.path.expanduser(data_dir)
  # Do the pre-setup tensor2tensor requires for flags and configurations.
  FLAGS.output_dir = model_dir
  FLAGS.data_dir = data_dir
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  # Create the basic hyper parameters.
  self.hparams = tpu_trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=data_dir,
      problem_name=FLAGS.problems)
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  # This server decodes locally, so force a single decode shard.
  decode_hp.add_hparam("shards", 1)
  decode_hp.add_hparam("shard_id", 0)
  # Create the estimator and final hyper parameters (CPU only: use_tpu=False).
  self.estimator = tpu_trainer_lib.create_estimator(
      FLAGS.model,
      self.hparams,
      tpu_trainer.create_run_config(),
      decode_hp, use_tpu=False)
  # Fetch the vocabulary and other helpful variables for decoding.
  self.source_vocab = self.hparams.problems[0].vocabulary["inputs"]
  self.targets_vocab = self.hparams.problems[0].vocabulary["targets"]
  # Fixed padding length used when encoding inputs -- presumably must match
  # what the serving path expects; TODO(review): confirm against callers.
  self.const_array_size = 10000
  # Prepare the Transformer's debug data directory: best-effort removal of
  # dumps from earlier runs. ignore_errors keeps init from crashing on a
  # concurrent delete or permission hiccup during cleanup.
  for run_dir in sorted(glob.glob(os.path.join("/tmp/t2t_server_dump",
                                               "run_*"))):
    shutil.rmtree(run_dir, ignore_errors=True)
# NOTE(review): the two lines below are web-scrape artifacts (Chinese page
# navigation text: "comment list" / "article table of contents"), not code.
# Commented out so the file remains valid Python; original text preserved.
# 评论列表
# 文章目录