def check_train(self, gpu):
    """Smoke-test end-to-end training of GraphCNN on a tiny synthetic dataset.

    Builds a fixed 4-node adjacency matrix, wires up model, optimizer,
    iterator, updater and trainer, then runs 10 epochs to verify the whole
    training pipeline executes without error.

    Args:
        gpu: Device id used as the 'main' device for ParallelUpdater
            (conventionally -1 for CPU, >= 0 for a GPU — TODO confirm
            against the calling test).
    """
    # NOTE(review): this temp dir is never removed; acceptable for a test,
    # but consider shutil.rmtree cleanup if runs accumulate.
    outdir = tempfile.mkdtemp()
    print("outdir: {}".format(outdir))

    n_classes = 2
    batch_size = 32
    devices = {'main': gpu}

    # Symmetric 4-node adjacency matrix defining the graph structure
    # consumed by the graph convolution layers.
    A = np.array([
        [0, 1, 1, 0],
        [1, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
    ]).astype(np.float32)

    model = graph_cnn.GraphCNN(A, n_out=n_classes)
    optimizer = optimizers.Adam(alpha=1e-4)
    optimizer.setup(model)

    train_dataset = EasyDataset(train=True, n_classes=n_classes)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_dataset, batch_size)

    updater = ParallelUpdater(train_iter, optimizer, devices=devices)
    trainer = chainer.training.Trainer(updater, (10, 'epoch'), out=outdir)

    # Per-epoch logging plus console progress so failures are visible.
    trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'main/loss', 'main/accuracy']))
    trainer.extend(extensions.ProgressBar())

    trainer.run()
Source: test_graph_cnn.py (Python source file)