def get_trainer(updater, evaluator, epochs, lr_decay=None):
    """Build a Chainer ``Trainer`` wired with the standard reporting extensions.

    Args:
        updater: Updater driving the optimization loop.
        evaluator: Evaluator extension run on the validation set.
        epochs (int): Number of epochs to train for.
        lr_decay (tuple or None): Optional ``(rate, every_n_epochs)`` pair.
            When given, the optimizer's ``lr`` hyperparameter is multiplied
            by ``rate`` every ``every_n_epochs`` epochs via
            ``extensions.ExponentialShift``. Defaults to ``None`` (no LR
            schedule), preserving the previous behavior.

    Returns:
        training.Trainer: Configured trainer writing output under ``result/``.
    """
    trainer = training.Trainer(updater, (epochs, 'epoch'), out='result')
    trainer.extend(evaluator)
    if lr_decay is not None:
        # Resolves the old TODO ("how to update every X epochs?"):
        # ExponentialShift multiplies 'lr' by `rate` each time it fires,
        # and the `trigger` keyword of `extend` sets the firing cadence.
        rate, every_n_epochs = lr_decay
        trainer.extend(extensions.ExponentialShift('lr', rate),
                       trigger=(every_n_epochs, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(
        (epochs, 'epoch'), update_interval=10))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss']))
    return trainer
# NOTE(review): web-scrape residue below, not code — "评论列表" ("comment
# list") and "文章目录" ("article table of contents") are blog-template
# text; safe to delete entirely.