mnist_cnn.py source code

python

Project: chainer-examples    Author: nocotan

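import chainer
import chainer.links as L
from chainer import training
from chainer.training import extensions

# The CNN model class is used below, but its definition is not part of
# this excerpt.


def main():
    # Classifier wraps the CNN with softmax cross-entropy loss (the default)
    # and reports 'loss' and 'accuracy' for each batch.
    model = L.Classifier(CNN())

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # ndim=3 returns each image as a (1, 28, 28) array for convolution layers.
    train, test = chainer.datasets.get_mnist(ndim=3)
    train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
    # One ordered, finite pass over the test set per evaluation.
    test_iter = chainer.iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer)
    # Stop after 5 epochs; logs are written to the 'result' directory.
    trainer = training.Trainer(updater, (5, 'epoch'), out='result')

    # Evaluate on the test set (once per epoch by default).
    trainer.extend(extensions.Evaluator(test_iter, model))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy']))
    trainer.extend(extensions.ProgressBar())

    trainer.run()


if __name__ == '__main__':
    main()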
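To make the training schedule concrete: `get_mnist` provides 60,000 training and 10,000 test examples, so with `batch_size=100` the trainer performs 600 updates per epoch and 3,000 over the 5 epochs, and each evaluation pass consumes 100 test batches. The sketch below is plain Python, not Chainer's API; `serial_batches` and `updates_per_run` are hypothetical helpers. `serial_batches` mimics how a `SerialIterator` built with `repeat=False, shuffle=False` walks a dataset once, in order, which is why the test iterator is configured that way: evaluation needs a finite, deterministic pass.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def serial_batches(dataset: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive batches in order, stopping after one pass.

    Hypothetical stand-in for SerialIterator(..., repeat=False, shuffle=False);
    the final batch may be smaller than batch_size.
    """
    for start in range(0, len(dataset), batch_size):
        yield list(dataset[start:start + batch_size])


def updates_per_run(n_train: int, batch_size: int, epochs: int) -> int:
    """Optimizer updates for a run of `epochs` full passes over the data."""
    # This sketch assumes batch_size divides n_train evenly (100 divides 60000).
    if n_train % batch_size:
        raise ValueError("batch_size must divide n_train in this sketch")
    return (n_train // batch_size) * epochs


if __name__ == "__main__":
    test_like = list(range(10_000))           # same size as the MNIST test set
    batches = list(serial_batches(test_like, 100))
    print(len(batches), len(batches[-1]))     # 100 100
    print(updates_per_run(60_000, 100, 5))    # 3000
```

With a size that is not a multiple of the batch size, such as 250 items in batches of 100, the one-pass iterator yields batches of 100, 100, and 50 before stopping.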