ewc_mnist.py source file

python

Project: chainer-EWC  Author: okdshin
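from chainer import iterators, optimizers, training
from chainer.training import extensions


def train_task(args, train_name, model, epoch_num,
               train_dataset, test_dataset_dict, batch_size):
    # Plain SGD with Chainer's default learning rate.
    optimizer = optimizers.SGD()
    optimizer.setup(model)

    # One shuffling, repeating iterator for training, plus one single-pass,
    # unshuffled iterator per named test dataset.
    train_iter = iterators.SerialIterator(train_dataset, batch_size)
    test_iter_dict = {name: iterators.SerialIterator(
            test_dataset, batch_size, repeat=False, shuffle=False)
            for name, test_dataset in test_dataset_dict.items()}

    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (epoch_num, 'epoch'), out=args.out)

    # Register each evaluator under its dataset name, so its reported values
    # appear as '<name>/main/loss' and '<name>/main/accuracy'.
    for name, test_iter in test_iter_dict.items():
        trainer.extend(extensions.Evaluator(test_iter, model), name)
    trainer.extend(extensions.LogReport())

    # Console columns: epoch, training loss, per-test losses,
    # training accuracy, per-test accuracies.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss'] +
        [test+'/main/loss' for test in test_dataset_dict.keys()] +
        ['main/accuracy'] +
        [test+'/main/accuracy' for test in test_dataset_dict.keys()]))
    trainer.extend(extensions.ProgressBar())

    # Plot every test accuracy curve into <train_name>.png under args.out
    # (PlotReport needs matplotlib).
    trainer.extend(extensions.PlotReport(
        [test+"/main/accuracy" for test
         in test_dataset_dict.keys()],
        file_name=train_name+".png"))
    trainer.run()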
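
The column list passed to PrintReport above is easier to read when built in isolation. The following is a minimal sketch in plain Python, with no Chainer dependency; the helper name `build_report_keys` and the dataset names `task_a` and `task_b` are hypothetical and are not part of chainer-EWC. It only reproduces the key ordering that `train_task` uses.

```python
def build_report_keys(test_names):
    """Return report column names in the same order train_task builds them.

    test_names: iterable of evaluator names (hypothetical examples below).
    """
    test_names = list(test_names)
    return (
        ['epoch', 'main/loss']
        + [name + '/main/loss' for name in test_names]
        + ['main/accuracy']
        + [name + '/main/accuracy' for name in test_names]
    )


if __name__ == '__main__':
    # Hypothetical dataset names, used only for illustration.
    for key in build_report_keys(['task_a', 'task_b']):
        print(key)
```

Each test dataset contributes one loss column and one accuracy column, prefixed by the name its Evaluator was registered under.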