train_classifier.py source code

python

Project: face-classifier-cnn  Author: nknytk
import json
import os

import chainer
from chainer.training import extensions

# cnn_feature_extractors, classifiers and feature_dataset are defined
# elsewhere in the project and are not shown in this excerpt.


def main(config_file):
    # Load the JSON config; it holds one section per training stage.
    with open(config_file) as fp:
        conf = json.load(fp)
    fe_conf = conf['feature_extractor']
    cl_conf = conf['classifier']

    # Rebuild the feature extractor by class name and restore its saved weights.
    fe_class = getattr(cnn_feature_extractors, fe_conf['model'])
    feature_extractor = fe_class(n_classes=fe_conf['n_classes'], n_base_units=fe_conf['n_base_units'])
    chainer.serializers.load_npz(fe_conf['out_file'], feature_extractor)

    # Classifier built on top of the loaded feature extractor.
    model = classifiers.MLPClassifier(cl_conf['n_classes'], feature_extractor)
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # device defaults to -1 (CPU) when the config does not set it.
    device = cl_conf.get('device', -1)
    train_dataset = feature_dataset(os.path.join(cl_conf['dataset_path'], 'train'), model)
    # Note: batch_size is read from the top level of the config, not from cl_conf.
    train_iter = chainer.iterators.SerialIterator(train_dataset, conf.get('batch_size', 1))
    updater = chainer.training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = chainer.training.Trainer(updater, (cl_conf['epoch'], 'epoch'), out='out_re')

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Evaluate on a 'test' subdirectory only if the dataset provides one.
    test_dataset_path = os.path.join(cl_conf['dataset_path'], 'test')
    if os.path.exists(test_dataset_path):
        test_dataset = feature_dataset(test_dataset_path, model)
        test_iter = chainer.iterators.SerialIterator(test_dataset, 10, repeat=False, shuffle=False)
        trainer.extend(extensions.Evaluator(test_iter, model, device=device))
        trainer.extend(extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy'
        ]))
    else:
        trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy']))

    trainer.run()

    # Save the trained classifier weights.
    chainer.serializers.save_npz(cl_conf['out_file'], model)
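
The function above reads every setting from a JSON config file, but the file itself is not shown. The following is a minimal sketch of what such a config might look like, built as a Python dict and serialized with `json`. Only the key names are taken from `main()`; every value (class name, paths, sizes, epoch count) is a hypothetical placeholder and not the project's actual configuration.

```python
import json

# Hypothetical values throughout; only the key names come from main().
example_conf = {
    # Read from the top level via conf.get('batch_size', 1).
    "batch_size": 32,
    "feature_extractor": {
        # Placeholder: name of a class looked up in cnn_feature_extractors.
        "model": "SomeExtractorClass",
        "n_classes": 10,
        "n_base_units": 16,
        # Weights file loaded with chainer.serializers.load_npz.
        "out_file": "feature_extractor.npz",
    },
    "classifier": {
        "n_classes": 5,
        # -1 runs on CPU; it is also the default when the key is absent.
        "device": -1,
        # Expected to contain 'train' and, optionally, 'test'.
        "dataset_path": "data/faces",
        "epoch": 20,
        # Destination passed to chainer.serializers.save_npz.
        "out_file": "classifier.npz",
    },
}

# JSON text that could be written to a file and passed to main().
config_text = json.dumps(example_conf, indent=2)
print(config_text)
```

Saving `config_text` to a file and passing that file's path to `main()` would exercise the same keys the function looks up, assuming the project modules and dataset directories exist.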