def create_trainer(
        config,
        project_path,
        updater,
        model,
        eval_func,
        iterator_test,
        iterator_train_varidation,
        loss_names,
        converter=chainer.dataset.convert.concat_examples,
):
    # type: (TrainConfig, str, typing.Any, typing.Dict, typing.Any, typing.Any, typing.Any, typing.Any, typing.Any) -> typing.Any
    """Build a ``chainer.training.Trainer`` with snapshot, evaluation and reporting extensions.

    Args:
        config: Training configuration. This function reads ``log_iteration``,
            ``save_iteration`` and ``gpu`` from it.
        project_path: Passed to the trainer as its ``out`` argument.
        updater: Updater handed to ``chainer.training.Trainer``.
        model: Mapping of models. ``model['main']`` is the object that gets
            snapshotted; the whole mapping is given to the evaluators as
            ``target``.
        eval_func: Forwarded to each evaluator as ``eval_func``.
        iterator_test: Iterator for the evaluator registered as ``eval/test``.
        iterator_train_varidation: Iterator for the evaluator registered as
            ``eval/train``.
        loss_names: Indexable sequence of loss names. The first entry selects
            the graph to dump; every entry becomes a printed report column.
        converter: Forwarded to each evaluator as ``converter``.

    Returns:
        The configured trainer.
    """
    trainer = chainer.training.Trainer(updater, out=project_path)

    every_log = (config.log_iteration, 'iteration')
    every_save = (config.save_iteration, 'iteration')
    test_key = 'eval/test'
    train_key = 'eval/train'

    # Save the main model every `save_iteration` iterations.
    trainer.extend(
        extensions.snapshot_object(model['main'], '{.updater.iteration}.model'),
        trigger=every_save,
    )

    # Dump the computational graph rooted at the first reported loss.
    trainer.extend(
        extensions.dump_graph('main/' + loss_names[0], out_name='main.dot'),
    )

    # One evaluator per dataset, test first, then train-validation; the
    # registration order is kept identical to the original sequence.
    evaluation_sources = (
        (test_key, iterator_test),
        (train_key, iterator_train_varidation),
    )
    for key, iterator in evaluation_sources:
        evaluator = utility.chainer_utility.NoVariableEvaluator(
            iterator,
            target=model,
            converter=converter,
            eval_func=eval_func,
            device=config.gpu,
        )
        trainer.extend(evaluator, name=key, trigger=every_log)

    # Report keys: the training-time value (empty prefix) followed by the
    # values published under each evaluator's name.
    prefixes = ('', test_key + '/', train_key + '/')
    report_target = [
        prefix + 'main/' + loss_name
        for prefix in prefixes
        for loss_name in loss_names
    ]

    trainer.extend(extensions.LogReport(trigger=every_log, log_name="log.txt"))
    trainer.extend(extensions.PrintReport(report_target))
    return trainer
# Web-page residue (not part of the module): "评论列表" (comment list), "文章目录" (table of contents).