env_test.py source code

python

Project: squadgym | Author: aleSuglia
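import argparse

# Assumed: NLTK's tokenizer (requires the "punkt" data package to be downloaded)
from nltk.tokenize import word_tokenize

# SquadEnv, ids2tokens and tokens2ids are provided elsewhere in the squadgym
# project; their import paths are not shown in the original snippet.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("env_data", type=str, help="Generated environment data filename in JSON format")
    args = parser.parse_args()

    env = SquadEnv(args.env_data)
    print("-- Initialized environment")

    # reset() returns the first (context, question) pair as sequences of token ids
    context, question = env.reset()
    done = False

    while not done:
        print("Context ids: {}".format(context))
        print("Question ids: {}".format(question))
        print("Context tokens: {}".format(ids2tokens(context, env.id2token)))
        print("Question tokens: {}".format(ids2tokens(question, env.id2token)))
        # Tokenize the typed answer, append the end-of-answer marker and map tokens to ids
        answer_tokens = tokens2ids(word_tokenize(input("Answer: ")) + ["#eos#"], env.token2id)

        # Feed the answer to the environment one token at a time, summing the rewards
        question_reward = 0
        for token in answer_tokens:
            (context, question), reward, done, _ = env.step(token)
            question_reward += reward
            if done:
                # Do not keep stepping once the episode has ended
                break

        print("You got {} reward".format(question_reward))


if __name__ == "__main__":
    main()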
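The loop above depends on squadgym's `SquadEnv` and on the `ids2tokens`/`tokens2ids` helpers, whose definitions are not part of this file. The sketch below is a minimal, hypothetical stand-in, assuming the helpers are plain dictionary lookups and that the environment follows the Gym-style `reset()`/`step()` contract the loop relies on. `ToyQAEnv`, `play`, the `#unk#` marker and the per-token reward rule are all invented for illustration and are not squadgym's actual implementation. `play` replays the loop from `main()` with scripted answers (split on whitespace instead of `word_tokenize`) so it runs without user input or a JSON data file.

```python
from typing import Dict, List, Sequence, Tuple

UNK = "#unk#"  # hypothetical unknown-token marker (assumption)
EOS = "#eos#"  # end-of-answer marker, as appended in main()


def tokens2ids(tokens: Sequence[str], token2id: Dict[str, int]) -> List[int]:
    """Map tokens to ids; unseen tokens fall back to UNK (assumed behaviour)."""
    unk_id = token2id[UNK]
    return [token2id.get(tok, unk_id) for tok in tokens]


def ids2tokens(ids: Sequence[int], id2token: Dict[int, str]) -> List[str]:
    """Inverse lookup: map ids back to their token strings."""
    return [id2token[i] for i in ids]


class ToyQAEnv:
    """Hypothetical stand-in exposing the reset()/step() shape main() relies on."""

    def __init__(self, examples, token2id: Dict[str, int]):
        self.token2id = token2id
        self.id2token = {i: tok for tok, i in token2id.items()}
        # Each example becomes (context ids, question ids, reference answer ids + EOS)
        self.examples: List[Tuple[List[int], List[int], List[int]]] = [
            (tokens2ids(c, token2id), tokens2ids(q, token2id), tokens2ids(a + [EOS], token2id))
            for c, q, a in examples
        ]
        self.idx = 0  # index of the current question
        self.pos = 0  # position within the current answer

    def reset(self):
        self.idx, self.pos = 0, 0
        context, question, _ = self.examples[self.idx]
        return context, question

    def step(self, token_id: int):
        context, question, answer = self.examples[self.idx]
        # Toy reward: 1 if the token matches the reference answer at this position
        expected = answer[self.pos] if self.pos < len(answer) else None
        reward = 1.0 if token_id == expected else 0.0
        self.pos += 1
        done = False
        if token_id == self.token2id[EOS]:
            # Answer finished: advance to the next question, or end the episode
            self.idx, self.pos = self.idx + 1, 0
            done = self.idx >= len(self.examples)
            if not done:
                context, question, _ = self.examples[self.idx]
        return (context, question), reward, done, {}


def play(env, answers: Sequence[str]) -> List[float]:
    """Mirror main()'s loop, reading answers from a list instead of input()."""
    rewards = []
    context, question = env.reset()
    done = False
    scripted = iter(answers)
    while not done:
        answer_ids = tokens2ids(next(scripted).split() + [EOS], env.token2id)
        question_reward = 0.0
        for token in answer_ids:
            (context, question), reward, done, _ = env.step(token)
            question_reward += reward
            if done:
                break
        rewards.append(question_reward)
    return rewards


VOCAB = [UNK, EOS, "what", "is", "the", "capital", "of", "france", "germany", "paris", "berlin"]
TOKEN2ID = {tok: i for i, tok in enumerate(VOCAB)}
EXAMPLES = [
    ("paris is the capital of france".split(), "what is the capital of france".split(), ["paris"]),
    ("berlin is the capital of germany".split(), "what is the capital of germany".split(), ["berlin"]),
]

if __name__ == "__main__":
    env = ToyQAEnv(EXAMPLES, TOKEN2ID)
    print(play(env, ["paris", "rome"]))  # [2.0, 1.0]
```

The toy reward pays 1 for every token that matches the reference at the same position, including `#eos#`, so a wrong answer that still ends with `#eos#` earns 1. squadgym's real reward function may be defined differently.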