def train_task(args, train_name, model, epoch_num,
               train_dataset, test_dataset_dict, batch_size):
    """Train ``model`` with plain SGD and evaluate it on several test sets.

    A Chainer ``Trainer`` is built around a ``StandardUpdater`` that
    iterates over ``train_dataset`` for ``epoch_num`` epochs. One
    ``Evaluator`` is attached per entry of ``test_dataset_dict``. Each is
    registered under that entry's key, which is why the reported metrics
    are looked up as ``'<key>/main/loss'`` and ``'<key>/main/accuracy'``.

    Args:
        args: Parsed options; only ``args.out`` (the trainer output
            directory) is read here.
        train_name: Base name of the accuracy plot image
            (``train_name + ".png"``).
        model: Link to optimize. The report keys assume it reports
            ``loss`` and ``accuracy`` (e.g. a ``Classifier``-style wrapper)
            -- TODO confirm against callers.
        epoch_num: Number of epochs to train.
        train_dataset: Dataset iterated with shuffling during training.
        test_dataset_dict: Mapping from evaluation name to dataset. Each
            dataset is iterated once per evaluation, in order.
        batch_size: Minibatch size, shared by training and evaluation.

    Returns:
        None. The model is updated in place by the optimizer.
    """
    optimizer = optimizers.SGD()  # default SGD hyperparameters
    optimizer.setup(model)

    train_iter = iterators.SerialIterator(train_dataset, batch_size)
    # Evaluation iterators must stop after one pass (repeat=False) and
    # keep a fixed order (shuffle=False).
    eval_iters = [
        (name, iterators.SerialIterator(dataset, batch_size,
                                        repeat=False, shuffle=False))
        for name, dataset in test_dataset_dict.items()
    ]

    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (epoch_num, 'epoch'), out=args.out)

    # Registering each Evaluator under the dataset's key makes its results
    # appear with that key as prefix in the reports below.
    for name, eval_iter in eval_iters:
        trainer.extend(extensions.Evaluator(eval_iter, model), name)

    test_names = list(test_dataset_dict.keys())
    loss_keys = ['main/loss'] + [name + '/main/loss' for name in test_names]
    test_accuracy_keys = [name + '/main/accuracy' for name in test_names]
    accuracy_keys = ['main/accuracy'] + test_accuracy_keys

    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch'] + loss_keys + accuracy_keys))
    trainer.extend(extensions.ProgressBar())
    # Only the per-test-set accuracies are plotted, not the training one.
    trainer.extend(extensions.PlotReport(
        test_accuracy_keys,
        file_name=train_name + ".png"))

    trainer.run()
# Page residue from the original article: "评论列表" (comment list),
# "文章目录" (table of contents).