def train_main(args):
"""
trains model specfied in args.
main method for train subcommand.
"""
    # load text
    with open(args.text_path) as f:
        text = f.read()
    logger.info("corpus length: %s.", len(text))
    # data iterator
    data_iter = DataIterator(text, args.batch_size, args.seq_len)
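    # note: DataIterator (defined elsewhere in this module) is assumed to
    # yield successive (input, target) character batches of shape
    # (batch_size, seq_len), walking the corpus in order so that hidden
    # state can be carried across batches.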
    # load or build model
    if args.restore:
        logger.info("restoring model.")
        # --restore may be given as a bare flag (True) or as an explicit path
        load_path = args.checkpoint_path if args.restore is True else args.restore
        model = load_model(load_path)
    else:
        net = Network(vocab_size=VOCAB_SIZE,
                      embedding_size=args.embedding_size,
                      rnn_size=args.rnn_size,
                      num_layers=args.num_layers,
                      drop_rate=args.drop_rate)
        model = L.Classifier(net)
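    # L.Classifier wraps the predictor and, by default, computes softmax
    # cross-entropy loss (reported as "main/loss") and accuracy against
    # the target characters.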
    # make checkpoint directory
    log_dir = make_dirs(args.checkpoint_path)
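    # note: make_dirs (defined elsewhere in this module) presumably creates
    # the checkpoint's parent directory and returns it; it doubles as the
    # trainer's output directory below.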
with open("{}.json".format(args.checkpoint_path), "w") as f:
json.dump(model.predictor.args, f, indent=2)
chainer.serializers.save_npz(args.checkpoint_path, model)
logger.info("model saved: %s.", args.checkpoint_path)
    # optimizer
    optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    optimizer.setup(model)
    # clip gradient norm
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.clip_norm))
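    # GradientClipping rescales the gradients whenever their global L2 norm
    # exceeds clip_norm, a standard guard against exploding RNN gradients.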
    # trainer
    updater = BpttUpdater(data_iter, optimizer)
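    # note: BpttUpdater (defined elsewhere in this module) presumably
    # implements truncated backpropagation through time, severing the
    # computational graph after each seq_len-step batch.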
    trainer = chainer.training.Trainer(updater, (args.num_epochs, 'epoch'), out=log_dir)
    # snapshot the model (every epoch by default) under the checkpoint file name
    trainer.extend(extensions.snapshot_object(model, filename=os.path.basename(args.checkpoint_path)))
    trainer.extend(extensions.ProgressBar(update_interval=1))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PlotReport(y_keys=["main/loss"]))
    trainer.extend(LoggerExtension(text))
    # training start
    model.predictor.reset_state()  # start from a clean RNN hidden state
    logger.info("start of training.")
    time_train = time.time()
    trainer.run()
    # training end
    duration_train = time.time() - time_train
    logger.info("end of training, duration: %ds.", duration_train)
    # generate sample text from a seed drawn from the corpus
    seed = generate_seed(text)
    generate_text(model, seed, 1024, 3)
    return model
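

# A minimal, hypothetical sketch of calling train_main directly; the real CLI
# presumably builds the Namespace via argparse with these same attribute names
# (the values below are illustrative only):
#
#     from argparse import Namespace
#
#     model = train_main(Namespace(
#         text_path="corpus.txt",
#         checkpoint_path="checkpoints/model.npz",
#         restore=False,
#         batch_size=64,
#         seq_len=64,
#         embedding_size=32,
#         rnn_size=128,
#         num_layers=2,
#         drop_rate=0.0,
#         learning_rate=0.001,
#         clip_norm=5.0,
#         num_epochs=32))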