parse_nn_swbd.py 文件源码

python
阅读 13 收藏 0 点赞 0 评论 0

项目:seq2seq_parser 作者: trangham283 项目源码 文件源码
def decode():
  """Decode the input file sentence-by-sentence with a trained model.

  Reads one sentence per line from ``FLAGS.decode_input_path`` and maps it
  to token ids with the sentence vocabulary. It then runs the model on one
  sentence at a time. The greedy (per-step argmax) output ids are truncated
  at the first ``data_utils.EOS_ID`` and mapped back through the reverse
  parse vocabulary. Each result is written as one line of
  ``FLAGS.decode_output_path``. Finally, it prints the wall-clock time spent
  in the decoding loop.

  Sentences too long for every bucket are reported on stdout and skipped.
  NOTE: no output line is written for a skipped sentence, so the output file
  can have fewer lines than the input and lose line-by-line alignment.
  """
  with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=NUM_THREADS)) as sess:
    # Create the model and load its parameters. `True` is create_model's
    # second positional argument (presumably forward-only / inference mode --
    # confirm against create_model).
    with tf.variable_scope("model"):
      model, _ = create_model(sess, True, attention=FLAGS.attention, model_path=FLAGS.model_path)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies: the forward map encodes input sentences; the reverse
    # map turns output ids back into parse tokens.
    sents_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.sents" % FLAGS.input_vocab_size)
    parse_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.parse" % FLAGS.output_vocab_size)
    sents_vocab, _ = data_utils.initialize_vocabulary(sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    start_time = time.time()
    # Decode
    with open(FLAGS.decode_input_path, 'r') as fin, open(FLAGS.decode_output_path, 'w') as fout:
      for line in fin:
        sentence = line.strip()
        token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), sents_vocab)
        # Index of the first bucket whose input size strictly exceeds the
        # sentence length (None if none does). This replaces a bare `except:`
        # around min(), which also swallowed unrelated errors -- e.g. a
        # NameError from `xrange` on Python 3 -- and skipped every sentence.
        bucket_id = next(
            (b for b, bucket in enumerate(_buckets) if bucket[0] > len(token_ids)),
            None)
        if bucket_id is None:
          print("Input sentence does not fit in any buckets. Skipping... ")
          print("\t", line)
          continue
        encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id)
        # Final positional `True` is presumably forward_only (no parameter
        # update) -- confirm against model.step.
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
        # Greedy decoding: argmax over axis 1 at each output step. The batch
        # size is 1, so each argmax holds a single id that int() unwraps.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        # Drop everything from the first EOS onward.
        if data_utils.EOS_ID in outputs:
          outputs = outputs[:outputs.index(data_utils.EOS_ID)]
        decoded_sentence = " ".join([tf.compat.as_str(rev_parse_vocab[output]) for output in outputs]) + '\n'
        fout.write(decoded_sentence)
    time_elapsed = time.time() - start_time
    print("Decoding time: ", time_elapsed)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号