def setup_mnist_trainer(self, display_log=False):
    """Assemble the objects needed for a multi-node MNIST training run.

    A ``L.Classifier`` wrapping ``MLP(100, 10)`` is created and attached to
    an Adam optimizer that has been wrapped by
    ``chainermn.create_multi_node_optimizer`` using ``self.communicator``.
    Only the process whose ``rank`` is 0 loads MNIST; every other process
    hands ``None`` to ``chainermn.scatter_dataset`` and uses whatever that
    call returns as its dataset.

    Args:
        display_log: Accepted for interface compatibility; this body does
            not read it.

    Returns:
        A 5-tuple ``(updater, optimizer, train_iter, test_iter, model)``.
    """
    communicator = self.communicator
    minibatch_size = 100
    hidden_units = 100

    classifier = L.Classifier(MLP(hidden_units, 10))

    base_optimizer = chainer.optimizers.Adam()
    multi_node_optimizer = chainermn.create_multi_node_optimizer(
        base_optimizer, communicator)
    multi_node_optimizer.setup(classifier)

    # Only rank 0 reads the dataset; the remaining ranks pass None and
    # rely on scatter_dataset for their data.
    full_train, full_test = (
        chainer.datasets.get_mnist()
        if communicator.rank == 0
        else (None, None)
    )

    # Every rank makes both scatter calls, train first and then test.
    local_train = chainermn.scatter_dataset(
        full_train, communicator, shuffle=True)
    local_test = chainermn.scatter_dataset(
        full_test, communicator, shuffle=True)

    training_iterator = chainer.iterators.SerialIterator(
        local_train, minibatch_size)
    # The test iterator is created with repeat=False and shuffle=False.
    evaluation_iterator = chainer.iterators.SerialIterator(
        local_test, minibatch_size, repeat=False, shuffle=False)

    trainer_updater = training.StandardUpdater(
        training_iterator, multi_node_optimizer)

    return (
        trainer_updater,
        multi_node_optimizer,
        training_iterator,
        evaluation_iterator,
        classifier,
    )
# [web-page residue, not code] Comment list
# [web-page residue, not code] Article table of contents