grl_train.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow 作者: liuyuemaicha 项目源码 文件源码
def test_decoder(config):
    """Interactively decode user-typed sentences with a trained dialogue model.

    Builds (or reuses) the vocabulary from the chitchat training corpus,
    restores either a supervised seq2seq model or an RL model depending on
    ``config.name_model``, then loops: read one line from stdin, tokenize it,
    pick a bucket, run one decode step, and print the answer(s). Stops at EOF.

    Args:
        config: configuration object; must provide ``train_dir``,
            ``vocab_size``, ``buckets`` and ``name_model`` (attribute set
            inferred from usage here — confirm against the config classes).

    Raises:
        ValueError: if ``config.name_model`` matches none of the known models.
    """
    # Vocabulary is derived from both sides (answer + query) of the corpus;
    # create_vocabulary is a no-op if the vocab file already exists.
    train_path = os.path.join(config.train_dir, "chitchat.train")
    data_path_list = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(config.train_dir, "vocab%d.all" % config.vocab_size)
    data_utils.create_vocabulary(vocab_path, data_path_list, config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    with tf.Session() as sess:
        # Select the model family matching the supplied config.
        if config.name_model in [gst_config.name_model, gcc_config.name_model, gbk_config.name_model]:
            model = create_st_model(sess, config, forward_only=True, name_scope=config.name_model)
        elif config.name_model in [grl_config.name_model, pre_grl_config.name_model]:
            model = create_rl_model(sess, config, forward_only=True, name_scope=config.name_model)
        else:
            # BUGFIX: previously `model` stayed unbound for an unknown name and
            # the code crashed later with a confusing NameError.
            raise ValueError("Unknown model name: %s" % config.name_model)

        # Decode one sentence at a time.
        model.batch_size = 1

        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)
            # Choose the smallest bucket that fits the input; fall back to the
            # largest bucket (input will be truncated) when none is big enough.
            bucket_id = len(config.buckets) - 1
            for i, bucket in enumerate(config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                # for/else runs only when no bucket matched (no break).
                # BUGFIX: the original passed `sentence` as a second print()
                # argument instead of %-formatting it into the message.
                print("Sentence truncated: %s" % sentence)

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]}, bucket_id)

            # Supervised model: greedy argmax decode over the output logits.
            if config.name_model in [gst_config.name_model, gcc_config.name_model, gbk_config.name_model]:
                output_logits, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                # Cut the answer at the first end-of-sentence token, if any.
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                print(" ".join([str(rev_vocab[output]) for output in outputs]))

            # RL model: beam search yields several candidate token sequences.
            elif config.name_model in [grl_config.name_model, pre_grl_config.name_model]:
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, reward=1,
                                                 bucket_id=bucket_id, forward_only=True)
                for i, output in enumerate(output_logits):
                    print("index: %d, answer tokens: %s" % (i, str(output)))
                    if data_utils.EOS_ID in output:
                        output = output[:output.index(data_utils.EOS_ID)]
                    print(" ".join([str(rev_vocab[out]) for out in output]))

            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号