train_feature_extractor.py source file

python
Project: face-classifier-cnn  Author: nknytk
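import json
import os

import chainer
from chainer.training import extensions

# Project-local names, assumed to come from the face-classifier-cnn repository:
# cnn_feature_extractors holds the model classes, and create_dataset() (defined
# elsewhere in the project) builds a Chainer dataset from a directory.
import cnn_feature_extractors


def main(config_file):
    # Read the 'feature_extractor' section of the JSON config file.
    with open(config_file) as fp:
        conf = json.load(fp)['feature_extractor']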

    # Look up the model class by name so the architecture is chosen in the config.
    model_class = getattr(cnn_feature_extractors, conf['model'])
    model = model_class(conf['n_classes'], conf['n_base_units'])

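    # If a '<out_file>.to_resume' file exists, start from its saved weights.
    # Only the model parameters are restored; the Adam state starts fresh.
    resume_file = conf['out_file'] + '.to_resume'
    if os.path.exists(resume_file):
        chainer.serializers.load_npz(resume_file, model)
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)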

    # device is a GPU ID; the default of -1 runs on the CPU.
    device = conf.get('device', -1)
    train_dataset = create_dataset(os.path.join(conf['dataset_path'], 'train'))
    train_iter = chainer.iterators.SerialIterator(train_dataset, conf.get('batch_size', 10))
    updater = chainer.training.StandardUpdater(train_iter, optimizer, device=device)
    # Logs and the computation graph are written to the 'out' directory.
    trainer = chainer.training.Trainer(updater, (conf['epoch'], 'epoch'), out='out')

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Evaluate on '<dataset_path>/test' after each epoch, but only if it exists.
    test_dataset_path = os.path.join(conf['dataset_path'], 'test')
    if os.path.exists(test_dataset_path):
        test_dataset = create_dataset(test_dataset_path)
        test_iter = chainer.iterators.SerialIterator(test_dataset, 20, repeat=False, shuffle=False)
        trainer.extend(extensions.Evaluator(test_iter, model, device=device))
        trainer.extend(extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy'
        ]))
    else:
        trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))

    trainer.run()

    # Move the parameters back to the CPU so the saved file loads without a GPU.
    model = model.to_cpu()
    chainer.serializers.save_npz(conf['out_file'], model)
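
The function above reads every setting from the 'feature_extractor' section of a JSON file. As a minimal sketch of what such a file might contain, the snippet below writes an example config and checks it for the keys that main() accesses. The helper load_feature_extractor_conf, the class name 'ExampleModel', and all values are hypothetical placeholders for illustration; only the key names and the two defaults (device -1, batch_size 10) are taken from the code above.

```python
import json
import os
import tempfile

# Keys that main() reads with conf[...]; a missing one raises KeyError there.
REQUIRED_KEYS = ('model', 'n_classes', 'n_base_units',
                 'out_file', 'dataset_path', 'epoch')
# Keys that main() reads with conf.get(...), with the defaults used there.
OPTIONAL_DEFAULTS = {'device': -1, 'batch_size': 10}


def load_feature_extractor_conf(config_file):
    """Hypothetical helper: load the section and check the keys main() needs."""
    with open(config_file) as fp:
        conf = json.load(fp)['feature_extractor']
    missing = [key for key in REQUIRED_KEYS if key not in conf]
    if missing:
        raise KeyError('missing config keys: ' + ', '.join(missing))
    # Fill in the same defaults main() would fall back to.
    return dict(OPTIONAL_DEFAULTS, **conf)


# All values below are placeholders, not taken from the project.
example = {
    'feature_extractor': {
        'model': 'ExampleModel',        # hypothetical class in cnn_feature_extractors
        'n_classes': 3,
        'n_base_units': 16,
        'out_file': 'feature_extractor.npz',
        'dataset_path': 'dataset',      # main() looks for dataset/train and dataset/test
        'epoch': 20,
    }
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'config.json')
    with open(path, 'w') as fp:
        json.dump(example, fp)
    conf = load_feature_extractor_conf(path)

print(conf['device'], conf['batch_size'])
```

Checking the keys up front is a design choice of this sketch: it reports all missing settings at once, rather than failing on the first missing one partway through setup.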