def make_lm_hook(d, seed_texts=None, max_seq_len=25, gpu=False,
                 method='sample', temperature=1, width=5,
                 early_stopping=None, validate=True):
    """
    Build a trainer hook for a plain language model.

    The returned hook, when invoked by the trainer at a checkpoint,
    optionally runs validation (logging the loss and registering it with
    ``early_stopping`` if one was given) and then logs a sample of text
    generated by the model.

    Generation keyword arguments (``seed_texts``, ``max_seq_len``, ``gpu``,
    ``method``, ``temperature``, ``width``) are captured by the closure and
    forwarded to ``trainer.model.generate`` unchanged.
    """
    def hook(trainer, epoch, batch_num, checkpoint):
        trainer.log("info", "Checking training...")
        if validate:
            # Fold the per-batch validation losses into a single scalar.
            valid_loss = sum(trainer.validate_model().pack())
            trainer.log("info", "Valid loss: {:g}".format(valid_loss))
            trainer.log("info", "Registering early stopping loss...")
            if early_stopping is not None:
                early_stopping.add_checkpoint(valid_loss)
        trainer.log("info", "Generating text...")
        scores, hyps = trainer.model.generate(
            d, seed_texts=seed_texts, max_seq_len=max_seq_len, gpu=gpu,
            method=method, temperature=temperature, width=width)
        # Render each hypothesis with a 1-based index for display.
        formatted = []
        for num, (score, hyp) in enumerate(zip(scores, hyps), 1):
            formatted.append(format_hyp(score, hyp, num, d))
        trainer.log("info", '\n***' + ''.join(formatted) + "\n***")
    return hook