chainer_model.py source code

python
Project: char-rnn-text-generation  Author: yxtay
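import json
import os
import time

import chainer
import chainer.links as L
from chainer.training import extensions

# Note: logger, DataIterator, Network, BpttUpdater, LoggerExtension, VOCAB_SIZE,
# load_model, make_dirs, generate_seed and generate_text are defined elsewhere
# in the project and are not shown in this excerpt.


def train_main(args):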
    """
    Trains the model specified in args.
    Main method for the train subcommand.
    """
    # load text
    with open(args.text_path) as f:
        text = f.read()
    logger.info("corpus length: %s.", len(text))

    # data iterator
    data_iter = DataIterator(text, args.batch_size, args.seq_len)

    # load or build model
    if args.restore:
        logger.info("restoring model.")
        load_path = args.checkpoint_path if args.restore is True else args.restore
        model = load_model(load_path)
    else:
        net = Network(vocab_size=VOCAB_SIZE,
                      embedding_size=args.embedding_size,
                      rnn_size=args.rnn_size,
                      num_layers=args.num_layers,
                      drop_rate=args.drop_rate)
        model = L.Classifier(net)

    # make checkpoint directory
    log_dir = make_dirs(args.checkpoint_path)
    with open("{}.json".format(args.checkpoint_path), "w") as f:
        json.dump(model.predictor.args, f, indent=2)
    chainer.serializers.save_npz(args.checkpoint_path, model)
    logger.info("model saved: %s.", args.checkpoint_path)

    # optimizer
    optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    optimizer.setup(model)
    # clip gradient norm
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.clip_norm))

    # trainer
    updater = BpttUpdater(data_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (args.num_epochs, 'epoch'), out=log_dir)
    trainer.extend(extensions.snapshot_object(model, filename=os.path.basename(args.checkpoint_path)))
    trainer.extend(extensions.ProgressBar(update_interval=1))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PlotReport(y_keys=["main/loss"]))
    trainer.extend(LoggerExtension(text))

    # training start
    model.predictor.reset_state()
    logger.info("start of training.")
    time_train = time.time()
    trainer.run()

    # training end
    duration_train = time.time() - time_train
    logger.info("end of training, duration: %ds.", duration_train)
    # generate text
    seed = generate_seed(text)
    generate_text(model, seed, 1024, 3)
    return model
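
The expression `args.checkpoint_path if args.restore is True else args.restore` suggests that the restore option can be passed either as a bare flag or with an explicit path. The excerpt does not show how the arguments are declared, so the following is only a minimal sketch of one way to get that behaviour with argparse. The option names, the default path, and the helper `resolve_load_path` are assumptions made for illustration, not the project's actual definitions.

```python
import argparse


def build_parser():
    """Hypothetical parser; option names and defaults are illustrative only."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-path", default="checkpoints/model.npz")
    # nargs="?" makes the value optional: a bare --restore stores const (True),
    # --restore PATH stores PATH, and omitting the flag stores default (False).
    parser.add_argument("--restore", nargs="?", const=True, default=False)
    return parser


def resolve_load_path(args):
    """Return the path to restore from, or None when no restore was requested."""
    if not args.restore:
        return None
    # Same expression as in train_main: fall back to the checkpoint path
    # when --restore was given without an explicit value.
    return args.checkpoint_path if args.restore is True else args.restore


if __name__ == "__main__":
    parser = build_parser()
    for argv in ([], ["--restore"], ["--restore", "other/model.npz"]):
        print(argv, "->", resolve_load_path(parser.parse_args(argv)))
```

With this setup a single option covers both "resume from the default checkpoint" and "resume from a specific file", which is why the code compares against `True` with `is` rather than testing truthiness alone.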