trainer.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:paint_transfer_c92 作者: Hiroshiba 项目源码 文件源码
def create_trainer(
        config: TrainConfig,
        project_path: str,
        updater,
        model: typing.Dict,
        eval_func,
        iterator_test,
        iterator_train_eval,
        loss_names,
        converter=chainer.dataset.convert.concat_examples,
        log_name='log.txt',
):
    """Build a chainer Trainer with snapshot, graph-dump, evaluation and logging extensions.

    Args:
        config: training configuration; ``log_iteration``, ``save_iteration``
            and ``gpu`` are read here.
        project_path: output directory, passed to the Trainer as ``out``.
        updater: the updater that drives the training loop.
        model: mapping of model name to link. It must contain the keys
            ``'encoder'``, ``'generator'`` and ``'mismatch_discriminator'``,
            which are snapshotted. Every key contributes columns to the
            printed report.
        eval_func: evaluation function handed to ``NoVariableEvaluator``.
        iterator_test: iterator consumed by the ``eval/test`` evaluator.
        iterator_train_eval: iterator consumed by the ``eval/train`` evaluator.
        loss_names: indexable sequence of loss names. The first entry is used
            for the computational-graph dump. All distinct entries, in
            first-seen order, are used for the printed report.
        converter: batch converter used by both evaluators.
        log_name: file name of the ``LogReport`` output.

    Returns:
        The configured ``chainer.training.Trainer``.

    Raises:
        KeyError: if ``model`` lacks one of the snapshotted keys.
        IndexError: if ``loss_names`` is empty.
    """
    trainer = chainer.training.Trainer(updater, out=project_path)

    log_trigger = (config.log_iteration, 'iteration')
    save_trigger = (config.save_iteration, 'iteration')

    eval_test_name = 'eval/test'
    eval_train_name = 'eval/train'

    # Periodically serialize each sub-network. The filename template embeds
    # the current iteration (e.g. 'encoder{.updater.iteration}.model').
    for snapshot_name in ('encoder', 'generator', 'mismatch_discriminator'):
        snapshot = extensions.snapshot_object(
            model[snapshot_name], snapshot_name + '{.updater.iteration}.model')
        trainer.extend(snapshot, trigger=save_trigger)

    trainer.extend(utility.chainer.dump_graph([
        'encoder/' + loss_names[0],
        'generator/' + loss_names[0],
        'mismatch_discriminator/' + loss_names[0],
    ], out_name='main.dot'))

    def _make_evaluator(iterator):
        # Both evaluators share the full model dict, converter and device;
        # only the data iterator differs.
        return utility.chainer.NoVariableEvaluator(
            iterator,
            target=model,
            converter=converter,
            eval_func=eval_func,
            device=config.gpu,
        )

    trainer.extend(_make_evaluator(iterator_test), name=eval_test_name, trigger=log_trigger)
    trainer.extend(_make_evaluator(iterator_train_eval), name=eval_train_name, trigger=log_trigger)

    # Deduplicate loss names while keeping their first-seen order. Iterating a
    # set here would make the PrintReport column order depend on the per-process
    # string-hash seed, so columns could shuffle between runs.
    unique_loss_names = []
    for loss_name in loss_names:
        if loss_name not in unique_loss_names:
            unique_loss_names.append(loss_name)

    # Report keys look like '<evaluator prefix><model>/<loss>'. The empty
    # prefix covers the values reported during training itself.
    report_target = [
        evaluator_name + model_name + '/' + loss_name
        for evaluator_name in ('', eval_test_name + '/', eval_train_name + '/')
        for model_name in model.keys()
        for loss_name in unique_loss_names
    ]

    trainer.extend(extensions.LogReport(trigger=log_trigger, log_name=log_name))
    trainer.extend(extensions.PrintReport(report_target))

    return trainer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号