def main(unused_argv):
  """Builds the vocab, params and batcher, then runs train/eval/decode.

  Args:
    unused_argv: Leftover command-line args from the app runner; ignored.

  Raises:
    ValueError: If any required special token is missing from the vocab.
  """
  vocab = dataset.Vocab(FLAGS.vocab_path, 200000)
  # Check for presence of required special tokens. Use an explicit raise
  # rather than assert: asserts are stripped under `python -O`, which would
  # silently skip this validation. Id 0 appears to be reserved, so every
  # special token must map to a strictly positive id.
  for token in (dataset.PAD_TOKEN, dataset.UNKNOWN_TOKEN,
                dataset.SENTENCE_START, dataset.SENTENCE_END,
                dataset.WORD_BEGIN, dataset.WORD_CONTINUE,
                dataset.WORD_END):
    if vocab.tokenToId(token) <= 0:
      raise ValueError('Required special token missing from vocab: %r' % token)
  params = selector.parameters(
      mode=FLAGS.mode,  # train, eval, decode
      min_lr=0.01,  # min learning rate.
      lr=0.1,  # learning rate
      batch_size=1,
      c_timesteps=600,  # context length
      q_timesteps=30,  # question length
      min_input_len=2,  # discard context, question < than this words
      hidden_size=200,  # for rnn cell and embedding
      emb_size=200,  # If 0, don't use embedding
      max_decode_steps=4,
      maxout_size=32,
      max_grad_norm=2)
  batcher = batch_reader.Generator(
      FLAGS.data_path, vocab, params,
      FLAGS.context_key, FLAGS.question_key, FLAGS.answer_key,
      FLAGS.max_context_sentences, FLAGS.max_question_sentences,
      bucketing=FLAGS.use_bucketing, truncate_input=FLAGS.truncate_input)
  tf.set_random_seed(FLAGS.random_seed)
  # The model is identical for every mode; build it once instead of
  # duplicating the construction in each branch. An unrecognized mode
  # falls through as a no-op, matching the original behavior.
  if params.mode in ('train', 'eval', 'decode'):
    model = selector.Model(
        params, len(vocab), num_cpus=FLAGS.num_cpus, num_gpus=FLAGS.num_gpus)
    if params.mode == 'train':
      _train(model, batcher)
    elif params.mode == 'eval':
      _eval(model, batcher)
    else:  # decode
      machine = decoder.Decoder(model, batcher, params, vocab)
      machine.loop()