def _build_arg_parser(framework, train_main, generate_main):
    """Build the command-line parser with ``train`` and ``generate`` subcommands.

    Args:
        framework: name interpolated into the parser description.
        train_main: callable stored as ``args.main`` when ``train`` is chosen.
        generate_main: callable stored as ``args.main`` when ``generate`` is chosen.

    Returns:
        The configured ``ArgumentParser``.
    """
    # Both subcommands default to the same log file next to this module.
    default_log_path = os.path.join(os.path.dirname(__file__), "main.log")

    arg_parser = ArgumentParser(
        description="{} character embeddings LSTM text generation model.".format(framework))
    subparsers = arg_parser.add_subparsers(title="subcommands")

    # train args
    train_parser = subparsers.add_parser("train", help="train model on text file")
    train_parser.add_argument("--checkpoint-path", required=True,
                              help="path to save or load model checkpoints (required)")
    train_parser.add_argument("--text-path", required=True,
                              help="path of text file for training (required)")
    # Bare ``--restore`` yields True (restore from --checkpoint-path);
    # ``--restore PATH`` yields the PATH string; omitted yields False.
    train_parser.add_argument("--restore", nargs="?", default=False, const=True,
                              help="whether to restore from checkpoint_path "
                                   "or from another path if specified")
    train_parser.add_argument("--seq-len", type=int, default=64,
                              help="sequence length of inputs and outputs (default: %(default)s)")
    train_parser.add_argument("--embedding-size", type=int, default=32,
                              help="character embedding size (default: %(default)s)")
    train_parser.add_argument("--rnn-size", type=int, default=128,
                              help="size of rnn cell (default: %(default)s)")
    train_parser.add_argument("--num-layers", type=int, default=2,
                              help="number of rnn layers (default: %(default)s)")
    train_parser.add_argument("--drop-rate", type=float, default=0.,
                              help="dropout rate for rnn layers (default: %(default)s)")
    train_parser.add_argument("--learning-rate", type=float, default=0.001,
                              help="learning rate (default: %(default)s)")
    train_parser.add_argument("--clip-norm", type=float, default=5.,
                              help="max norm to clip gradient (default: %(default)s)")
    train_parser.add_argument("--batch-size", type=int, default=64,
                              help="training batch size (default: %(default)s)")
    train_parser.add_argument("--num-epochs", type=int, default=32,
                              help="number of epochs for training (default: %(default)s)")
    train_parser.add_argument("--log-path", default=default_log_path,
                              help="path of log file (default: %(default)s)")
    train_parser.set_defaults(main=train_main)

    # generate args
    generate_parser = subparsers.add_parser("generate", help="generate text from trained model")
    generate_parser.add_argument("--checkpoint-path", required=True,
                                 help="path to load model checkpoints (required)")
    # Exactly one seed source must be given: a text file or a literal sequence.
    group = generate_parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--text-path", help="path of text file to generate seed")
    group.add_argument("--seed", default=None, help="seed character sequence")
    generate_parser.add_argument("--length", type=int, default=1024,
                                 help="length of character sequence to generate (default: %(default)s)")
    generate_parser.add_argument("--top-n", type=int, default=3,
                                 help="number of top choices to sample (default: %(default)s)")
    generate_parser.add_argument("--log-path", default=default_log_path,
                                 help="path of log file (default: %(default)s)")
    generate_parser.set_defaults(main=generate_main)

    return arg_parser


def main(framework, train_main, generate_main):
    """Parse command-line arguments and dispatch to the chosen subcommand.

    Args:
        framework: name interpolated into the parser description.
        train_main: callable invoked as ``train_main(args)`` for ``train``.
        generate_main: callable invoked as ``generate_main(args)`` for ``generate``.

    Raises:
        SystemExit: status 2 (via argparse) on invalid arguments or when no
            subcommand is given; status 1 when the subcommand raises an
            ``Exception`` (its traceback is logged first).
    """
    arg_parser = _build_arg_parser(framework, train_main, generate_main)
    args = arg_parser.parse_args()
    # Subparsers are optional by default in Python 3. Without a subcommand,
    # ``args`` has neither ``main`` nor ``log_path``, so report a usage error
    # instead of crashing with an AttributeError below.
    if getattr(args, "main", None) is None:
        arg_parser.error("a subcommand is required: train or generate")

    get_logger("__main__", log_path=args.log_path, console=True)
    logger = get_logger(__name__, log_path=args.log_path, console=True)
    logger.debug("call: %s", " ".join(sys.argv))
    logger.debug("ArgumentParser: %s", args)

    try:
        args.main(args)
    except Exception as e:
        # Log the full traceback, then exit non-zero so shells and schedulers
        # can detect that the run failed.
        logger.exception(e)
        sys.exit(1)
# 评论列表 (comment list)
# 文章目录 (table of contents)