def main(config_file):
    """Train the configured feature extractor and save its weights.

    Reads the ``'feature_extractor'`` section of the JSON file at
    ``config_file``, builds the model class named there, trains it with
    Adam on the ``train`` sub-directory of the dataset path, and writes the
    trained parameters to ``out_file`` in NPZ format.

    Config keys read from the ``'feature_extractor'`` section:
        model: attribute name looked up on ``cnn_feature_extractors``.
        n_classes, n_base_units: positional arguments for the model class.
        out_file: destination for the trained weights; if
            ``out_file + '.to_resume'`` exists, weights are loaded from it
            before training.
        dataset_path: directory containing ``train`` and, optionally,
            ``test``.
        epoch: number of epochs to train for.
        device (optional): device id handed to the updater and evaluator;
            defaults to -1 (CPU in Chainer's convention).
        batch_size (optional): training batch size; defaults to 10.

    Args:
        config_file: path to the JSON configuration file.
    """
    with open(config_file) as config_fp:
        settings = json.load(config_fp)['feature_extractor']

    # Resolve the model class by name and instantiate it.
    extractor_cls = getattr(cnn_feature_extractors, settings['model'])
    model = extractor_cls(settings['n_classes'], settings['n_base_units'])

    # Warm-start from a previously stored snapshot when one is present.
    out_file = settings['out_file']
    checkpoint = out_file + '.to_resume'
    if os.path.exists(checkpoint):
        chainer.serializers.load_npz(checkpoint, model)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    device = settings.get('device', -1)
    dataset_root = settings['dataset_path']

    train_iter = chainer.iterators.SerialIterator(
        create_dataset(os.path.join(dataset_root, 'train')),
        settings.get('batch_size', 10),
    )
    updater = chainer.training.StandardUpdater(
        train_iter, optimizer, device=device)
    trainer = chainer.training.Trainer(
        updater, (settings['epoch'], 'epoch'), out='out')

    for extension in (
        extensions.dump_graph('main/loss'),
        extensions.LogReport(),
        extensions.ProgressBar(update_interval=10),
    ):
        trainer.extend(extension)

    # Validation is optional: it is only wired in when a 'test' directory
    # exists next to 'train', and the printed columns follow suit.
    eval_dir = os.path.join(dataset_root, 'test')
    if os.path.exists(eval_dir):
        eval_iter = chainer.iterators.SerialIterator(
            create_dataset(eval_dir), 20, repeat=False, shuffle=False)
        trainer.extend(
            extensions.Evaluator(eval_iter, model, device=device))
        report_keys = [
            'epoch', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy',
        ]
    else:
        report_keys = ['epoch', 'main/loss', 'main/accuracy']
    trainer.extend(extensions.PrintReport(report_keys))

    trainer.run()

    # Move the parameters back to host memory before serializing.
    chainer.serializers.save_npz(out_file, model.to_cpu())
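# Source code of train_feature_extractor.py (language: python)
# Web-page residue from extraction: views 20, favorites 0, likes 0, comments 0
# (followed by empty "comment list" and "table of contents" page sections)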