main.py source code

python

Project: dynamic-coattention-network    Author: marshmelloX
def main(unused_argv):
  vocab = dataset.Vocab(FLAGS.vocab_path, 200000)
  # Check for presence of required special tokens.
  assert vocab.tokenToId(dataset.PAD_TOKEN) > 0
  assert vocab.tokenToId(dataset.UNKNOWN_TOKEN) > 0
  assert vocab.tokenToId(dataset.SENTENCE_START) > 0
  assert vocab.tokenToId(dataset.SENTENCE_END) > 0
  assert vocab.tokenToId(dataset.WORD_BEGIN) > 0
  assert vocab.tokenToId(dataset.WORD_CONTINUE) > 0
  assert vocab.tokenToId(dataset.WORD_END) > 0

  params = selector.parameters(
      mode=FLAGS.mode,  # train, eval, decode
      min_lr=0.01,  # min learning rate.
      lr=0.1,  # learning rate
      batch_size=1,
      c_timesteps=600,  # context length
      q_timesteps=30,  # question length
      min_input_len=2,  # discard contexts/questions shorter than this many words
      hidden_size=200,  # for rnn cell and embedding
      emb_size=200,  # If 0, don't use embedding
      max_decode_steps=4,
      maxout_size=32,
      max_grad_norm=2)

  batcher = batch_reader.Generator(
      FLAGS.data_path, vocab, params,
      FLAGS.context_key, FLAGS.question_key, FLAGS.answer_key,
      FLAGS.max_context_sentences, FLAGS.max_question_sentences,
      bucketing=FLAGS.use_bucketing, truncate_input=FLAGS.truncate_input)

  tf.set_random_seed(FLAGS.random_seed)

  if params.mode == 'train':
    model = selector.Model(
        params, len(vocab), num_cpus=FLAGS.num_cpus, num_gpus=FLAGS.num_gpus)
    _train(model, batcher)
  elif params.mode == 'eval':
    model = selector.Model(
        params, len(vocab), num_cpus=FLAGS.num_cpus, num_gpus=FLAGS.num_gpus)
    _eval(model, batcher)
  elif params.mode == 'decode':
    model = selector.Model(
        params, len(vocab), num_cpus=FLAGS.num_cpus, num_gpus=FLAGS.num_gpus)
    machine = decoder.Decoder(model, batcher, params, vocab)
    machine.loop()
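
The snippet above references FLAGS (e.g. FLAGS.vocab_path, FLAGS.mode, FLAGS.data_path) and the project's modules (dataset, selector, batch_reader, decoder) without showing how they are defined. Below is a minimal sketch, assuming the TensorFlow 1.x tf.app.flags convention, of the imports, flag definitions, and entry point that such a main() would typically sit alongside. The flag names are taken from the snippet itself; the defaults and help strings are illustrative assumptions, not the project's actual values.

# Hypothetical sketch of the surrounding file scaffolding (assumptions noted above).
import tensorflow as tf

import batch_reader   # project module: Generator that batches context/question/answer examples
import dataset        # project module: Vocab and the special-token constants asserted in main()
import decoder        # project module: Decoder used in 'decode' mode
import selector       # project module: parameters() and Model

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('vocab_path', '', 'Path to the vocabulary file.')
tf.app.flags.DEFINE_string('data_path', '', 'Path to the training/eval/decode data.')
tf.app.flags.DEFINE_string('mode', 'train', "One of 'train', 'eval', 'decode'.")
tf.app.flags.DEFINE_string('context_key', 'context', 'Feature key for the context passage.')
tf.app.flags.DEFINE_string('question_key', 'question', 'Feature key for the question.')
tf.app.flags.DEFINE_string('answer_key', 'answer', 'Feature key for the answer span.')
tf.app.flags.DEFINE_integer('max_context_sentences', 100, 'Maximum sentences per context.')
tf.app.flags.DEFINE_integer('max_question_sentences', 1, 'Maximum sentences per question.')
tf.app.flags.DEFINE_boolean('use_bucketing', False, 'Bucket input examples by length.')
tf.app.flags.DEFINE_boolean('truncate_input', False, 'Truncate over-long inputs instead of discarding them.')
tf.app.flags.DEFINE_integer('random_seed', 111, 'Graph-level random seed.')
tf.app.flags.DEFINE_integer('num_cpus', 1, 'Number of CPUs available to the model.')
tf.app.flags.DEFINE_integer('num_gpus', 0, 'Number of GPUs available to the model.')

if __name__ == '__main__':
  tf.app.run()  # parses the flags, then calls main(unused_argv)

Assuming flags along these lines, the script would be launched with something like: python main.py --mode=train --vocab_path=/path/to/vocab --data_path=/path/to/data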