def main():
    """Command-line entry point: build two agents and run self-play.

    Parses the command-line options, seeds the random number generators via
    ``utils.set_seed``, loads one model per agent, wraps each model in the
    agent class chosen by ``get_agent_type``, and hands both agents to a
    ``SelfPlay`` driver whose ``run()`` method is then invoked.

    The function takes no parameters and returns ``None``; all configuration
    comes from ``sys.argv`` through ``argparse``. ``argparse`` raises
    ``SystemExit`` on ``--help`` or on malformed options.
    """
    parser = argparse.ArgumentParser(description='selfplaying script')

    # Model and data locations.
    # NOTE(review): these have no default, so they are None when omitted --
    # confirm whether the loaders below tolerate that.
    parser.add_argument('--alice_model_file', type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str,
                        help='Bob model file')
    parser.add_argument('--context_file', type=str,
                        help='context file')

    # General run settings.
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='temperature')
    parser.add_argument('--verbose', action='store_true', default=False,
                        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
                        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--max_turns', type=int, default=20,
                        help='maximum number of turns in a dialog')
    parser.add_argument('--log_file', type=str, default='',
                        help='log successful dialogs to file for training')

    # Agent-selection and rollout settings.
    parser.add_argument('--smart_alice', action='store_true', default=False,
                        help='make Alice smart again')
    parser.add_argument('--fast_rollout', action='store_true', default=False,
                        help='to use faster rollouts')
    parser.add_argument('--rollout_bsz', type=int, default=100,
                        help='rollout batch size')
    parser.add_argument('--rollout_count_threshold', type=int, default=3,
                        help='rollout count threshold')
    parser.add_argument('--smart_bob', action='store_true', default=False,
                        help='make Bob smart again')

    parser.add_argument('--ref_text', type=str,
                        help='file with the reference text')
    parser.add_argument('--domain', type=str, default='object_division',
                        help='domain for the dialogue')

    args = parser.parse_args()

    # Seed before any model is loaded or any agent is constructed, exactly as
    # the original statement order did.
    utils.set_seed(args.seed)

    # Alice: load the model, pick the agent class, then instantiate it.
    alice_model = utils.load_model(args.alice_model_file)
    alice_ty = get_agent_type(alice_model, args.smart_alice, args.fast_rollout)
    alice = alice_ty(alice_model, args, name='Alice')

    # Bob: same construction steps, driven by Bob's own model file and flag.
    bob_model = utils.load_model(args.bob_model_file)
    bob_ty = get_agent_type(bob_model, args.smart_bob, args.fast_rollout)
    bob = bob_ty(bob_model, args, name='Bob')

    # Wire the agents, the logger and the context source into the driver.
    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    selfplay = SelfPlay(dialog, ctx_gen, args, logger)
    selfplay.run()
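# NOTE: the lines below are web-page residue from the listing this file was
# copied from (page title, language tag, view/favorite/like/comment counters,
# comment list, table of contents); they are not part of the program.
# selfplay.py 文件源码
# python
# 阅读 19
# 收藏 0
# 点赞 0
# 评论 0
# 评论列表
# 文章目录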