test_checkpoint.py source code

python

Project: chainermn  Author: chainer
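import chainer
import chainer.links as L
from chainer import training

import chainermn

# MLP (a small multi-layer perceptron) and self.communicator (a ChainerMN
# communicator) are provided by the enclosing test module and test case;
# neither is shown in this excerpt.


def setup_mnist_trainer(self, display_log=False):
    # display_log is not used in this excerpt.
    batchsize = 100
    n_units = 100

    comm = self.communicator
    model = L.Classifier(MLP(n_units, 10))

    # Wrap Adam so that gradients are all-reduced across workers.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Only rank 0 loads MNIST; scatter_dataset distributes shards from rank 0
    # to every worker, so the other ranks pass None.
    if comm.rank == 0:
        train, test = chainer.datasets.get_mnist()
    else:
        train, test = None, None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    test = chainermn.scatter_dataset(test, comm, shuffle=True)

    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    # One unshuffled pass over the local test shard per evaluation.
    test_iter = chainer.iterators.SerialIterator(test, batchsize,
                                                 repeat=False,
                                                 shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer)

    return updater, optimizer, train_iter, test_iter, model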
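As a rough illustration of what the two scatter_dataset calls mean for each worker, the sketch below models rank 0 splitting a dataset of n examples into near-equal contiguous shards after an optional shuffle, and counts how many batches a non-repeating iterator yields over one shard. scatter_indices and iterations_per_epoch are hypothetical helpers, not part of the ChainerMN API; this is a simplified model, and the real scatter_dataset may size shards differently when the dataset size is not divisible by the number of workers.

```python
import math
import random


def scatter_indices(n_examples, n_workers, shuffle=True, seed=0):
    """Hypothetical helper: split indices 0..n_examples-1 into n_workers
    near-equal contiguous shards, optionally shuffling first.

    A simplified model of rank 0's role in chainermn.scatter_dataset;
    the real implementation may differ in how it sizes the shards.
    """
    order = list(range(n_examples))
    if shuffle:
        # Fixed seed only so that the sketch is reproducible.
        random.Random(seed).shuffle(order)
    shards = []
    for rank in range(n_workers):
        begin = n_examples * rank // n_workers
        end = n_examples * (rank + 1) // n_workers
        shards.append(order[begin:end])
    return shards


def iterations_per_epoch(shard_size, batchsize):
    """Hypothetical helper: number of batches a repeat=False iterator
    yields over one shard (the last batch may be smaller)."""
    return math.ceil(shard_size / batchsize)


if __name__ == "__main__":
    # MNIST's test split has 10,000 images; assume 4 workers.
    test_shards = scatter_indices(10000, 4)
    print([len(s) for s in test_shards])
    print(iterations_per_epoch(len(test_shards[0]), 100))
```

With 4 workers, each worker receives 2,500 test images, so the test iterator in setup_mnist_trainer (batchsize 100, repeat=False) yields 25 batches per evaluation on each worker.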