def main(config_file):
    """Train an MLP classifier on top of a previously saved feature extractor.

    The JSON file at ``config_file`` must contain two sections:

    ``feature_extractor``
        ``model`` (name of a class in ``cnn_feature_extractors``),
        ``n_classes``, ``n_base_units`` and ``out_file`` (npz file the
        feature extractor's parameters are loaded from).
    ``classifier``
        ``n_classes``, ``dataset_path`` (directory holding a ``train``
        sub-directory and, optionally, a ``test`` sub-directory),
        ``epoch``, ``out_file`` (npz file the trained model is saved to)
        and, optionally, ``device`` (defaults to -1) and ``batch_size``.

    ``batch_size`` is looked up in the ``classifier`` section first, then at
    the top level of the config, and finally defaults to 1.

    Trainer output is written to the ``out_re`` directory.

    Args:
        config_file: Path to the JSON configuration file.
    """
    with open(config_file) as fp:
        conf = json.load(fp)
    fe_conf = conf['feature_extractor']
    cl_conf = conf['classifier']

    # Rebuild the feature extractor with the same hyper-parameters it was
    # created with, then restore its saved parameters.
    fe_class = getattr(cnn_feature_extractors, fe_conf['model'])
    feature_extractor = fe_class(
        n_classes=fe_conf['n_classes'],
        n_base_units=fe_conf['n_base_units'],
    )
    chainer.serializers.load_npz(fe_conf['out_file'], feature_extractor)

    model = classifiers.MLPClassifier(cl_conf['n_classes'], feature_extractor)
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    device = cl_conf.get('device', -1)

    # Every other classifier setting lives in cl_conf, so honour a
    # 'batch_size' given there; keep the top-level key as a fallback so that
    # existing configuration files behave exactly as before.
    batch_size = cl_conf.get('batch_size', conf.get('batch_size', 1))

    # NOTE(review): feature_dataset is defined elsewhere; it receives the
    # model, presumably to pre-compute features -- confirm.
    train_dataset = feature_dataset(
        os.path.join(cl_conf['dataset_path'], 'train'), model)
    train_iter = chainer.iterators.SerialIterator(train_dataset, batch_size)
    updater = chainer.training.StandardUpdater(
        train_iter, optimizer, device=device)
    trainer = chainer.training.Trainer(
        updater, (cl_conf['epoch'], 'epoch'), out='out_re')

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Validation is optional: it is enabled only when a 'test' directory
    # exists next to 'train'.
    test_dataset_path = os.path.join(cl_conf['dataset_path'], 'test')
    if os.path.exists(test_dataset_path):
        test_dataset = feature_dataset(test_dataset_path, model)
        # Single ordered pass over the test data per evaluation.
        test_iter = chainer.iterators.SerialIterator(
            test_dataset, 10, repeat=False, shuffle=False)
        trainer.extend(extensions.Evaluator(test_iter, model, device=device))
        trainer.extend(extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy'
        ]))
    else:
        trainer.extend(
            extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))

    trainer.run()
    chainer.serializers.save_npz(cl_conf['out_file'], model)
评论列表
文章目录