def create_trainer(
        config: TrainConfig,
        project_path: str,
        updater,
        model: typing.Dict,
        eval_func,
        iterator_test,
        iterator_train_eval,
        loss_names,
        converter=chainer.dataset.convert.concat_examples,
        log_name='log.txt',
):
    """Build a Chainer ``Trainer`` with snapshot, graph-dump, evaluation and logging extensions.

    Args:
        config: Training configuration. ``log_iteration`` and ``save_iteration``
            are used as iteration-based trigger intervals for logging/evaluation
            and for snapshots respectively; ``gpu`` is passed to the evaluators
            as ``device``.
        project_path: Used as the trainer's ``out`` directory.
        updater: Updater that drives the training loop.
        model: Mapping from model name to model. Must contain the keys
            ``'encoder'``, ``'generator'`` and ``'mismatch_discriminator'``
            (these are snapshotted); every key is also used to build the
            ``PrintReport`` entries. The whole mapping is the evaluators' target.
        eval_func: Evaluation function handed to the evaluators.
        iterator_test: Iterator evaluated under the name ``'eval/test'``.
        iterator_train_eval: Iterator evaluated under the name ``'eval/train'``.
        loss_names: Sequence of loss names reported per model. Its first element
            names the losses used as roots of the graph dump (an empty sequence
            raises ``IndexError``). Duplicates are reported only once.
        converter: Batch converter passed to the evaluators.
        log_name: File name for the ``LogReport`` output.

    Returns:
        The configured ``chainer.training.Trainer``.
    """
    trainer = chainer.training.Trainer(updater, out=project_path)

    log_trigger = (config.log_iteration, 'iteration')
    save_trigger = (config.save_iteration, 'iteration')
    eval_test_name = 'eval/test'
    eval_train_name = 'eval/train'

    # Networks that are periodically saved and whose first loss roots the graph
    # dump. The '{.updater.iteration}' placeholder is filled in when the snapshot
    # is written, so each save gets a distinct file name.
    snapshot_targets = ('encoder', 'generator', 'mismatch_discriminator')
    for target_name in snapshot_targets:
        snapshot = extensions.snapshot_object(
            model[target_name], target_name + '{.updater.iteration}.model')
        trainer.extend(snapshot, trigger=save_trigger)

    trainer.extend(utility.chainer.dump_graph(
        [target_name + '/' + loss_names[0] for target_name in snapshot_targets],
        out_name='main.dot',
    ))

    def _make_evaluator(iterator):
        # Both evaluators share every setting except the data they iterate over.
        return utility.chainer.NoVariableEvaluator(
            iterator,
            target=model,
            converter=converter,
            eval_func=eval_func,
            device=config.gpu,
        )

    trainer.extend(_make_evaluator(iterator_test), name=eval_test_name, trigger=log_trigger)
    trainer.extend(_make_evaluator(iterator_train_eval), name=eval_train_name, trigger=log_trigger)

    # Deduplicate while keeping first-occurrence order. Iterating a set() here
    # would make the PrintReport column order depend on string hash
    # randomization, so it could differ between runs.
    unique_loss_names = list(dict.fromkeys(loss_names))

    # '' selects the training observations; the other prefixes select the
    # values reported by the evaluators registered above.
    report_target = []
    for evaluator_prefix in ('', eval_test_name + '/', eval_train_name + '/'):
        for model_name in model:
            for loss_name in unique_loss_names:
                report_target.append(evaluator_prefix + model_name + '/' + loss_name)

    trainer.extend(extensions.LogReport(trigger=log_trigger, log_name=log_name))
    trainer.extend(extensions.PrintReport(report_target))
    return trainer
评论列表
文章目录